"""Copied from https://github.com/salute-developers/GigaAM/blob/main/gigaam/encoder.py""" |
|
import math |
|
from abc import ABC, abstractmethod |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
IMPORT_FLASH = False |
|
IMPORT_FLASH_ERR = "Flash Attention not installed." |
|
|
|
|
|
|
|
|
|
def rtt_half(x: Tensor) -> Tensor: |
|
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] |
|
return torch.cat([-x2, x1], dim=x1.ndim - 1) |
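

# Illustrative example (added note, not part of the original source):
# rtt_half(torch.tensor([1.0, 2.0, 3.0, 4.0])) -> tensor([-3., -4., 1., 2.])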


def apply_rotary_pos_emb(
    q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, offset: int = 0
) -> Tuple[Tensor, Tensor]:
    """
    Applies Rotary Position Embeddings to query and key tensors.
    """
    cos, sin = (
        cos[offset : q.shape[0] + offset, ...],
        sin[offset : q.shape[0] + offset, ...],
    )
    return (q * cos) + (rtt_half(q) * sin), (k * cos) + (rtt_half(k) * sin)
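

# Shape note (added, not part of the original source): as used by
# RotaryPositionMultiHeadAttention below, q and k are laid out as
# (seq_len, batch, n_heads, head_dim), while cos/sin come from
# RotaryPositionalEmbedding with shape (seq_len, 1, 1, head_dim), so the
# element-wise products broadcast over the batch and head axes.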


class StridingSubsampling(nn.Module):
    """
    Strided Subsampling layer used to reduce the sequence length.
    """

    def __init__(
        self,
        subsampling_factor: int,
        feat_in: int,
        feat_out: int,
        conv_channels: int,
    ):
        super().__init__()
        self._sampling_num = int(math.log(subsampling_factor, 2))
        self._stride = 2
        self._kernel_size = 3
        self._padding = (self._kernel_size - 1) // 2

        layers: List[nn.Module] = []
        in_channels = 1
        for _ in range(self._sampling_num):
            layers.append(
                torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=conv_channels,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    padding=self._padding,
                )
            )
            layers.append(nn.ReLU())
            in_channels = conv_channels

        out_length = self.calc_output_length(torch.tensor(feat_in))
        self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
        self.conv = torch.nn.Sequential(*layers)

    def calc_output_length(self, lengths: Tensor) -> Tensor:
        """
        Calculates the output length after applying the subsampling.
        """
        lengths = lengths.to(torch.float)
        add_pad = 2 * self._padding - self._kernel_size
        for _ in range(self._sampling_num):
            lengths = torch.div(lengths + add_pad, self._stride) + 1.0
            lengths = torch.floor(lengths)
        return lengths.to(dtype=torch.int)

    def forward(self, x: Tensor, lengths: Tensor) -> Tuple[Tensor, Tensor]:
        x = self.conv(x.unsqueeze(1))
        b, _, t, _ = x.size()
        x = self.out(x.transpose(1, 2).reshape(b, t, -1))
        return x, self.calc_output_length(lengths)
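

# Subsampling arithmetic (added note, not part of the original source): with
# kernel_size=3, stride=2, padding=1, each conv layer maps a length L to
# floor((L + 2 - 3) / 2) + 1, roughly halving both the time and feature axes.
# With subsampling_factor=4 (two conv layers) and feat_in=64, an input of shape
# (B, T=200, 64) becomes (B, 50, conv_channels * 16) before the final Linear.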


class MultiHeadAttention(nn.Module, ABC):
    """
    Base class of Multi-Head Attention Mechanisms.
    """

    def __init__(self, n_head: int, n_feat: int, flash_attn=False):
        super().__init__()
        assert n_feat % n_head == 0
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.flash_attn = flash_attn
        if self.flash_attn and not IMPORT_FLASH:
            raise RuntimeError(
                f"flash_attn_func could not be imported ({IMPORT_FLASH_ERR}). "
                "Please install flash_attn or use the --no_flash flag. "
                "If you have already done this, "
                "the --force-reinstall flag might be useful."
            )

    def forward_qkv(
        self, query: Tensor, key: Tensor, value: Tensor
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Projects the inputs into queries, keys, and values for multi-head attention.
        """
        b = query.size(0)
        q = self.linear_q(query).view(b, -1, self.h, self.d_k)
        k = self.linear_k(key).view(b, -1, self.h, self.d_k)
        v = self.linear_v(value).view(b, -1, self.h, self.d_k)
        if self.flash_attn:
            return q, k, v
        return q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

    def forward_attention(
        self, value: Tensor, scores: Tensor, mask: Optional[Tensor]
    ) -> Tensor:
        """
        Computes the scaled dot-product attention given the projected values and scores.
        """
        b = value.size(0)
        if mask is not None:
            # Positions marked True in the mask are padding: they get a large
            # negative score before softmax and zero attention weight after it.
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask, -10000.0)
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        else:
            attn = torch.softmax(scores, dim=-1)
        x = torch.matmul(attn, value)
        x = x.transpose(1, 2).reshape(b, -1, self.h * self.d_k)
        return self.linear_out(x)


class RelPositionMultiHeadAttention(MultiHeadAttention):
    """
    Relative Position Multi-Head Attention module.
    """

    def __init__(self, n_head: int, n_feat: int):
        super().__init__(n_head, n_feat)
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        self.pos_bias_u = nn.Parameter(torch.FloatTensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.FloatTensor(self.h, self.d_k))

    def rel_shift(self, x: Tensor) -> Tensor:
        # Converts relative-position indexing to absolute indexing by padding
        # and reshaping (the Transformer-XL shift trick).
        b, h, qlen, pos_len = x.size()
        x = torch.nn.functional.pad(x, pad=(1, 0))
        x = x.view(b, h, -1, qlen)
        return x[:, :, 1:].view(b, h, qlen, pos_len)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: Tensor,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)
        p = self.linear_pos(pos_emb)
        p = p.view(pos_emb.shape[0], -1, self.h, self.d_k).transpose(1, 2)
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
        # Content-position term ("bd") and content-content term ("ac"),
        # following the Transformer-XL score decomposition.
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
        matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]
        scores = (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)


class RotaryPositionMultiHeadAttention(MultiHeadAttention):
    """
    Rotary Position Multi-Head Attention module.
    """

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        pos_emb: List[Tensor],
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        b, t, _ = value.size()
        query = query.transpose(0, 1).view(t, b, self.h, self.d_k)
        key = key.transpose(0, 1).view(t, b, self.h, self.d_k)
        value = value.transpose(0, 1).view(t, b, self.h, self.d_k)

        cos, sin = pos_emb
        query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)

        q, k, v = self.forward_qkv(
            query.view(t, b, self.h * self.d_k).transpose(0, 1),
            key.view(t, b, self.h * self.d_k).transpose(0, 1),
            value.view(t, b, self.h * self.d_k).transpose(0, 1),
        )

        scores = torch.matmul(q, k.transpose(-2, -1) / math.sqrt(self.d_k))
        out = self.forward_attention(v, scores, mask)
        return out


class PositionalEncoding(nn.Module, ABC):
    """
    Base class of Positional Encodings.
    """

    def __init__(self, dim: int, base: int):
        super().__init__()
        self.dim = dim
        self.base = base

    @abstractmethod
    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        pass

    def extend_pe(self, length: int, device: torch.device):
        """
        Extends the positional encoding buffer to process longer sequences.
        """
        pe = self.create_pe(length, device)
        if pe is None:
            return
        if hasattr(self, "pe"):
            self.pe = pe
        else:
            self.register_buffer("pe", pe, persistent=False)


class RelPositionalEmbedding(PositionalEncoding):
    """
    Relative Positional Embedding module.
    """

    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        """
        Creates the relative positional encoding matrix.
        """
        if hasattr(self, "pe") and self.pe.shape[1] >= 2 * length - 1:
            return None
        positions = torch.arange(length - 1, -length, -1, device=device).unsqueeze(1)
        pos_length = positions.size(0)
        pe = torch.zeros(pos_length, self.dim, device=positions.device)
        div_term = torch.exp(
            torch.arange(0, self.dim, 2, device=pe.device)
            * -(math.log(10000.0) / self.dim)
        )
        pe[:, 0::2] = torch.sin(positions * div_term)
        pe[:, 1::2] = torch.cos(positions * div_term)
        return pe.unsqueeze(0)

    def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
        input_len = x.size(1)
        center_pos = self.pe.size(1) // 2 + 1
        start_pos = center_pos - input_len
        end_pos = center_pos + input_len - 1
        return x, self.pe[:, start_pos:end_pos]
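

# Indexing note (added, not part of the original source): for a maximum length L,
# create_pe stores 2L - 1 rows covering relative positions L-1 .. -(L-1); for an
# input of length T, forward returns the centered slice of 2T - 1 rows, i.e.
# relative positions T-1 .. -(T-1).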


class RotaryPositionalEmbedding(PositionalEncoding):
    """
    Rotary Positional Embedding module.
    """

    def create_pe(self, length: int, device: torch.device) -> Optional[Tensor]:
        """
        Creates or extends the rotary positional encoding matrix.
        """
        if hasattr(self, "pe") and self.pe.size(0) >= 2 * length:
            return None
        positions = torch.arange(0, length, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        t = torch.arange(length, device=positions.device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(positions.device)
        return torch.cat([emb.cos()[:, None, None, :], emb.sin()[:, None, None, :]])

    def forward(self, x: torch.Tensor) -> Tuple[Tensor, List[Tensor]]:
        cos_emb = self.pe[0 : x.shape[1]]
        half_pe = self.pe.shape[0] // 2
        sin_emb = self.pe[half_pe : half_pe + x.shape[1]]
        return x, [cos_emb, sin_emb]
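

# Buffer layout (added note, not part of the original source): create_pe stacks the
# cosine table on top of the sine table along dim 0, so self.pe has shape
# (2 * max_len, 1, 1, dim); forward slices cos from the top half and sin from the
# bottom half for the first x.shape[1] positions.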


class ConformerConvolution(nn.Module):
    """
    Conformer Convolution module.
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int,
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0
        self.pointwise_conv1 = nn.Conv1d(d_model, d_model * 2, kernel_size=1)
        self.depthwise_conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
            groups=d_model,
            bias=True,
        )
        self.batch_norm = nn.BatchNorm1d(d_model)
        self.activation = nn.SiLU()
        self.pointwise_conv2 = nn.Conv1d(d_model, d_model, kernel_size=1)

    def forward(self, x: Tensor, pad_mask: Optional[Tensor] = None) -> Tensor:
        # (batch, time, d_model) -> (batch, d_model, time) for the 1D convolutions.
        x = x.transpose(1, 2)
        # The first pointwise conv doubles the channels; GLU halves them back to d_model.
        x = self.pointwise_conv1(x)
        x = nn.functional.glu(x, dim=1)
        if pad_mask is not None:
            x = x.masked_fill(pad_mask.unsqueeze(1), 0.0)
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)
        return x.transpose(1, 2)


class ConformerFeedForward(nn.Module):
    """
    Conformer Feed Forward module.
    """

    def __init__(self, d_model: int, d_ff: int, use_bias=True):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias)
        self.activation = nn.SiLU()
        self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear2(self.activation(self.linear1(x)))


class ConformerLayer(nn.Module):
    """
    Conformer Layer module.

    This module combines several submodules, including feed-forward networks,
    a depthwise separable convolution, and multi-head self-attention,
    to form a single Conformer block.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        self_attention_model: str,
        n_heads: int = 16,
        conv_kernel_size: int = 31,
        flash_attn: bool = False,
    ):
        super().__init__()
        self.fc_factor = 0.5
        self.norm_feed_forward1 = nn.LayerNorm(d_model)
        self.feed_forward1 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
        self.norm_conv = nn.LayerNorm(d_model)
        self.conv = ConformerConvolution(
            d_model=d_model,
            kernel_size=conv_kernel_size,
        )
        self.norm_self_att = nn.LayerNorm(d_model)
        if self_attention_model == "rotary":
            self.self_attn: nn.Module = RotaryPositionMultiHeadAttention(
                n_head=n_heads,
                n_feat=d_model,
                flash_attn=flash_attn,
            )
        else:
            assert not flash_attn, "flash_attn is not supported for rel_pos attention"
            self.self_attn = RelPositionMultiHeadAttention(
                n_head=n_heads,
                n_feat=d_model,
            )
        self.norm_feed_forward2 = nn.LayerNorm(d_model)
        self.feed_forward2 = ConformerFeedForward(d_model=d_model, d_ff=d_ff)
        self.norm_out = nn.LayerNorm(d_model)

    def forward(
        self,
        x: Tensor,
        pos_emb: Union[Tensor, List[Tensor]],
        att_mask: Optional[Tensor] = None,
        pad_mask: Optional[Tensor] = None,
    ) -> Tensor:
        # First feed-forward module (half-step residual).
        residual = x
        x = self.norm_feed_forward1(x)
        x = self.feed_forward1(x)
        residual = residual + x * self.fc_factor

        # Multi-head self-attention module.
        x = self.norm_self_att(residual)
        x = self.self_attn(x, x, x, pos_emb, mask=att_mask)
        residual = residual + x

        # Convolution module.
        x = self.norm_conv(residual)
        x = self.conv(x, pad_mask=pad_mask)
        residual = residual + x

        # Second feed-forward module (half-step residual).
        x = self.norm_feed_forward2(residual)
        x = self.feed_forward2(x)
        residual = residual + x * self.fc_factor

        x = self.norm_out(residual)
        return x


class ConformerEncoder(nn.Module):
    """
    Conformer Encoder module.

    This module encapsulates the entire Conformer encoder architecture,
    consisting of a StridingSubsampling layer, positional embeddings, and
    a stack of Conformer Layers.
    It serves as the main component responsible for processing speech features.
    """

    def __init__(
        self,
        feat_in: int = 64,
        n_layers: int = 16,
        d_model: int = 768,
        subsampling_factor: int = 4,
        ff_expansion_factor: int = 4,
        self_attention_model: str = "rotary",
        n_heads: int = 16,
        pos_emb_max_len: int = 5000,
        conv_kernel_size: int = 31,
        flash_attn: bool = False,
    ):
        super().__init__()
        self.feat_in = feat_in
        assert self_attention_model in [
            "rotary",
            "rel_pos",
        ], f"Unsupported attention type: {self_attention_model}"

        self.pre_encode = StridingSubsampling(
            subsampling_factor=subsampling_factor,
            feat_in=feat_in,
            feat_out=d_model,
            conv_channels=d_model,
        )

        if self_attention_model == "rotary":
            self.pos_enc: nn.Module = RotaryPositionalEmbedding(
                d_model // n_heads, pos_emb_max_len
            )
        else:
            self.pos_enc = RelPositionalEmbedding(d_model, pos_emb_max_len)

        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            layer = ConformerLayer(
                d_model=d_model,
                d_ff=d_model * ff_expansion_factor,
                self_attention_model=self_attention_model,
                n_heads=n_heads,
                conv_kernel_size=conv_kernel_size,
                flash_attn=flash_attn,
            )
            self.layers.append(layer)

        self.pos_enc.extend_pe(pos_emb_max_len, next(self.parameters()).device)

    def input_example(
        self,
        batch_size: int = 1,
        seqlen: int = 200,
    ):
        device = next(self.parameters()).device
        features = torch.zeros(batch_size, self.feat_in, seqlen)
        feature_lengths = torch.full([batch_size], features.shape[-1])
        return features.float().to(device), feature_lengths.to(device)

    def input_names(self):
        return ["audio_signal", "length"]

    def output_names(self):
        return ["encoded", "encoded_len"]

    def dynamic_axes(self):
        return {
            "audio_signal": {0: "batch_size", 2: "seq_len"},
            "length": {0: "batch_size"},
            "encoded": {0: "batch_size", 1: "seq_len"},
            "encoded_len": {0: "batch_size"},
        }

    def forward(self, audio_signal: Tensor, length: Tensor) -> Tuple[Tensor, Tensor]:
        audio_signal, length = self.pre_encode(
            x=audio_signal.transpose(1, 2), lengths=length
        )

        max_len = audio_signal.size(1)
        audio_signal, pos_emb = self.pos_enc(x=audio_signal)

        # True for valid (non-padding) frames, False for padding.
        pad_mask = torch.arange(0, max_len, device=audio_signal.device).expand(
            length.size(0), -1
        ) < length.unsqueeze(-1)

        att_mask = None
        if audio_signal.shape[0] > 1:
            att_mask = pad_mask.unsqueeze(1).repeat([1, max_len, 1])
            att_mask = torch.logical_and(att_mask, att_mask.transpose(1, 2))
            att_mask = ~att_mask

        # Both masks are inverted before use: True marks positions to be masked out.
        pad_mask = ~pad_mask

        for layer in self.layers:
            audio_signal = layer(
                x=audio_signal,
                pos_emb=pos_emb,
                att_mask=att_mask,
                pad_mask=pad_mask,
            )

        return audio_signal.transpose(1, 2), length
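

if __name__ == "__main__":
    # Minimal smoke test (added; not part of the original GigaAM source).
    # The hyperparameters below are illustrative only and much smaller than
    # those used by the released GigaAM models.
    encoder = ConformerEncoder(feat_in=64, n_layers=2, d_model=128, n_heads=4)
    encoder.eval()
    features, lengths = encoder.input_example(batch_size=2, seqlen=200)
    with torch.no_grad():
        encoded, encoded_len = encoder(features, lengths)
    # With subsampling_factor=4, the 200-frame input is reduced to 50 frames:
    # encoded has shape (2, 128, 50) and encoded_len equals [50, 50].
    print(encoded.shape, encoded_len)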