|
|
|
import time |
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
from typing import Iterable, Optional |
|
|
|
from funasr.register import tables |
|
from funasr.models.ctc.ctc import CTC |
|
from funasr.utils.datadir_writer import DatadirWriter |
|
from funasr.models.paraformer.search import Hypothesis |
|
from funasr.train_utils.device_funcs import force_gatherable |
|
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
|
from funasr.metrics.compute_acc import compute_accuracy, th_accuracy |
|
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
|
|
|
|
|
def ctc_forced_align(
    log_probs: torch.Tensor,
    targets: torch.Tensor,
    input_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
    blank: int = 0,
    ignore_id: int = -1,
) -> torch.Tensor:
    """Align a CTC label sequence to an emission with Viterbi decoding.

    Args:
        log_probs (Tensor): Per-frame CTC scores of shape ``(B, T, C)``, where ``B`` is
            the batch size, ``T`` the number of frames and ``C`` the number of
            symbols including blank. Scores along a path are added together.
        targets (Tensor): Target label ids of shape ``(B, L)``. Entries equal to
            ``ignore_id`` are treated as padding. The tensor is not modified.
        input_lengths (Tensor): Number of valid frames per item, shape ``(B,)``;
            each value must be in ``[1, T]``.
        target_lengths (Tensor): Number of valid labels per item, shape ``(B,)``.
        blank (int, optional): Index of the blank symbol. (Default: 0)
        ignore_id (int, optional): Padding value used in ``targets``. (Default: -1)

    Returns:
        Tensor: Frame-level label ids of shape ``(B, T)``. Frames at or beyond an
        item's ``input_lengths`` are filled with ``blank``.
    """
    # Replace padding out-of-place so the caller's ``targets`` is left untouched.
    targets = targets.masked_fill(targets == ignore_id, blank)

    batch_size, input_time_size, _ = log_probs.size()
    device = log_probs.device
    bsz_indices = torch.arange(batch_size, device=input_lengths.device)

    # Extended label sequence: blank, y1, blank, y2, ..., yL, blank -> (B, 2L + 1).
    ext_targets = torch.cat(
        (
            torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
            torch.full_like(targets[:, :1], blank),
        ),
        dim=-1,
    )
    # A state may be entered from two states back (skipping a blank) only when its
    # label differs from the label two positions earlier.
    diff_labels = torch.cat(
        (
            torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
            ext_targets[:, 2:] != ext_targets[:, :-2],
        ),
        dim=1,
    )

    neg_inf = torch.tensor(float("-inf"), device=device, dtype=log_probs.dtype)
    # Two leading -inf columns let the "stay / step / skip" candidates be read as
    # plain slices of ``best_score``.
    padding_num = 2
    padded_t = padding_num + ext_targets.size(-1)
    best_score = torch.full((batch_size, padded_t), neg_inf, device=device, dtype=log_probs.dtype)
    best_score[:, padding_num + 0] = log_probs[:, 0, blank]
    best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, ext_targets[:, 1]]

    # backpointers[b, t, s] = how many states back (0, 1 or 2) the best path into
    # state s at frame t came from. Padding columns stay 0.
    backpointers = torch.zeros(
        (batch_size, input_time_size, padded_t), device=device, dtype=targets.dtype
    )
    frame_lengths = input_lengths.to(device)

    for t in range(1, input_time_size):
        prev = torch.stack(
            (best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
        )
        prev_max_value, prev_max_idx = prev.max(dim=0)
        new_score = log_probs[:, t].gather(-1, ext_targets) + prev_max_value
        # Freeze the scores of items whose valid frames are exhausted, so the final
        # state is chosen from the scores at frame ``input_lengths - 1`` rather than
        # from scores that kept absorbing padding frames up to ``T - 1``.
        active = (t < frame_lengths).unsqueeze(-1)
        best_score[:, padding_num:] = torch.where(active, new_score, best_score[:, padding_num:])
        backpointers[:, t, padding_num:] = prev_max_idx

    # A valid path ends either on the last label or on the trailing blank.
    end_scores = best_score.gather(
        -1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
    )

    path = torch.zeros((batch_size, input_time_size), device=device, dtype=torch.long)
    path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + end_scores.argmax(dim=-1)

    # Backtrack. For frames past an item's length ``path`` is 0 and the padding
    # backpointer column 0 is never written, so those frames add nothing.
    for t in range(input_time_size - 1, 0, -1):
        target_indices = path[:, t]
        prev_max_idx = backpointers[bsz_indices, t, target_indices]
        path[:, t - 1] += target_indices - prev_max_idx

    alignments = ext_targets.gather(dim=-1, index=(path - padding_num).clamp(min=0))
    return alignments
|
|
|
class SinusoidalPositionEncoder(torch.nn.Module):
    """Add fixed sinusoidal position encodings to a feature sequence.

    The module has no parameters. The encoding width always equals the last
    dimension of the input passed to :meth:`forward`.
    """

    def __init__(self, d_model=80, dropout_rate=0.1):
        # The original defined ``__int__`` (a typo), so no constructor accepting
        # these arguments existed. They are accepted for configuration
        # compatibility only and are not used.
        super().__init__()

    def encode(
        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
    ):
        """Compute sinusoidal encodings for ``positions``.

        Args:
            positions: Position indices. All entries are flattened onto a single
                time axis, so callers pass shape ``(1, N)``.
            depth: Encoding width. The first half of the channels holds sines and
                the second half cosines.
            dtype: Computation and output dtype.

        Returns:
            Tensor of shape ``(1, N, depth)`` for even ``depth``.
        """
        # NOTE: positions are cast to ``dtype`` before sin/cos. In bfloat16/float16,
        # large integer positions are not exactly representable; this is kept
        # as-is for parity with existing checkpoints.
        positions = positions.type(dtype)
        device = positions.device
        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
            depth / 2 - 1
        )
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
        )
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
        encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
        return encoding.type(dtype)

    def forward(self, x):
        """Return ``x`` plus encodings of positions ``1..time``.

        Args:
            x: Features of shape ``(batch, time, dim)``.
        """
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding
|
|
|
|
|
class PositionwiseFeedForward(torch.nn.Module):
    """Position-wise feed-forward layer: ``w_2(dropout(activation(w_1(x))))``.

    Args:
        idim (int): Input (and output) dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate applied after the activation.
        activation (torch.nn.Module, optional): Activation module. When omitted,
            a fresh ``torch.nn.ReLU`` is created for this instance.
    """

    def __init__(self, idim, hidden_units, dropout_rate, activation=None):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        # A module used as a default argument is evaluated once and shared by every
        # layer built without an explicit activation, so create one per instance.
        self.activation = activation if activation is not None else torch.nn.ReLU()

    def forward(self, x):
        """Apply the feed-forward transform to every position of ``x``."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
|
|
|
|
class MultiHeadedAttentionSANM(nn.Module):
    """Multi-head self-attention combined with an FSMN memory block (SAN-M).

    A single linear layer produces queries, keys and values. Scaled dot-product
    attention runs over the heads. A depthwise 1-D convolution over time (the
    FSMN block) is applied to the un-split values, and its output is added to
    the attention output.

    Args:
        n_head (int): The number of heads.
        in_feat (int): Input feature dimension.
        n_feat (int): Attention feature dimension; must be divisible by ``n_head``.
        dropout_rate (float): Dropout rate for attention weights and FSMN output.
        kernel_size (int): Kernel size of the FSMN convolution.
        sanm_shfit (int): Extra left padding for the FSMN convolution, which moves
            its receptive field toward earlier frames.
        lora_list, lora_rank, lora_alpha, lora_dropout: Accepted for configuration
            compatibility; not used here.
    """

    def __init__(
        self,
        n_head,
        in_feat,
        n_feat,
        dropout_rate,
        kernel_size,
        sanm_shfit=0,
        lora_list=None,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    ):
        """Construct a MultiHeadedAttentionSANM object."""
        super().__init__()
        assert n_feat % n_head == 0

        self.d_k = n_feat // n_head
        self.h = n_head

        self.linear_out = nn.Linear(n_feat, n_feat)
        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

        # Depthwise (groups == channels) convolution over time, no bias.
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )

        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)

    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
        """Apply the FSMN memory block.

        Args:
            inputs (torch.Tensor): Values before head split, (#batch, time, n_feat).
            mask (torch.Tensor or None): Padding mask, reshaped to (#batch, time, 1).
            mask_shfit_chunk (torch.Tensor, optional): Extra mask multiplied into ``mask``.

        Returns:
            torch.Tensor: (#batch, time, n_feat), zeroed at masked positions.
        """
        b, t, d = inputs.size()
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            if mask_shfit_chunk is not None:
                mask = mask * mask_shfit_chunk
            inputs = inputs * mask

        x = inputs.transpose(1, 2)
        x = self.pad_fn(x)
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        x += inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x

    def forward_qkv(self, x):
        """Project ``x`` to per-head queries, keys and values.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).

        Returns:
            torch.Tensor: Query tensor (#batch, n_head, time, d_k).
            torch.Tensor: Key tensor (#batch, n_head, time, d_k).
            torch.Tensor: Value tensor (#batch, n_head, time, d_k).
            torch.Tensor: Value tensor before the head split (#batch, time, n_feat).
        """
        b, t, d = x.size()
        q_k_v = self.linear_q_k_v(x)
        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)
        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)
        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)
        return q_h, k_h, v_h, v

    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
        """Compute the attention context vector.

        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor or None): Mask (#batch, 1, time2) or (#batch, time1, time2);
                zero entries are masked out.
            mask_att_chunk_encoder (torch.Tensor, optional): Extra mask multiplied into ``mask``.

        Returns:
            torch.Tensor: Output (#batch, time1, n_feat) after ``linear_out``.
        """
        n_batch = value.size(0)
        if mask is not None:
            if mask_att_chunk_encoder is not None:
                mask = mask * mask_att_chunk_encoder

            mask = mask.unsqueeze(1).eq(0)

            min_value = -float("inf")
            scores = scores.masked_fill(mask, min_value)
            # Fully masked rows become NaN after softmax; zero them explicitly.
            attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
        else:
            attn = torch.softmax(scores, dim=-1)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)

        return self.linear_out(x)

    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute SAN-M self-attention over ``x``.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            mask (torch.Tensor or None): Mask tensor (#batch, 1, time).
            mask_shfit_chunk (torch.Tensor, optional): Extra mask for the FSMN block.
            mask_att_chunk_encoder (torch.Tensor, optional): Extra mask for attention.

        Returns:
            torch.Tensor: Output tensor (#batch, time, n_feat).
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
        return att_outs + fsmn_memory

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute unmasked self-attention for one chunk, optionally with a key/value cache.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_feat).
            cache (dict or None): ``{"k", "v"}`` tensors (#batch, n_head, time_cache, d_k)
                from previous chunks. Updated in place when given.
            chunk_size (sequence or None): Indices 1 and 2 are read. The last
                ``chunk_size[2]`` frames are not cached, and at most
                ``look_back * chunk_size[1]`` frames are kept.
            look_back (int): Number of past chunks to keep; ``-1`` keeps all. Caching
                is used only when ``chunk_size`` is given and ``look_back`` is positive
                or ``-1``.

        Returns:
            tuple: Output tensor (#batch, time, n_feat) and the (possibly new) cache.
        """
        q_h, k_h, v_h, v = self.forward_qkv(x)
        # Parenthesized so that look_back == -1 without a chunk_size does not index None.
        if chunk_size is not None and (look_back > 0 or look_back == -1):
            if cache is not None:
                k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
                v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
                k_h = torch.cat((cache["k"], k_h), dim=2)
                v_h = torch.cat((cache["v"], v_h), dim=2)

                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
                if look_back != -1:
                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
            else:
                cache_tmp = {
                    "k": k_h[:, :, : -(chunk_size[2]), :],
                    "v": v_h[:, :, : -(chunk_size[2]), :],
                }
                cache = cache_tmp
        fsmn_memory = self.forward_fsmn(v, None)
        q_h = q_h * self.d_k ** (-0.5)
        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
        att_outs = self.forward_attention(v_h, scores, None)
        return att_outs + fsmn_memory, cache
|
|
|
|
|
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that always normalizes in float32.

    The input and any affine parameters are upcast to float32 for the
    computation, and the result is cast back to the input's dtype.
    """

    def forward(self, input):
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        normalized = F.layer_norm(input.float(), self.normalized_shape, weight, bias, self.eps)
        return normalized.type_as(input)
|
|
|
|
|
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a padding mask from sequence lengths.

    Args:
        lengths: Integer tensor of lengths.
        maxlen: Mask width; defaults to ``lengths.max()``.
        dtype: Output dtype.
        device: If given, the mask is moved to this device.

    Returns:
        Tensor of shape ``lengths.shape + (maxlen,)`` whose entry ``j`` is 1
        (True) when ``j < length`` and 0 otherwise.
    """
    limit = lengths.max() if maxlen is None else maxlen
    positions = torch.arange(0, limit, 1, device=lengths.device)
    mask = (positions < lengths.unsqueeze(-1)).detach().to(dtype)
    if device is None:
        return mask
    return mask.to(device)
|
|
|
|
|
|
|
class EncoderLayerSANM(nn.Module):
    """Encoder layer: SAN-M self-attention followed by a position-wise feed-forward.

    Args:
        in_size (int): Feature dimension of the layer input.
        size (int): Feature dimension of the layer output.
        self_attn (nn.Module): Self-attention module mapping ``in_size`` to ``size``.
        feed_forward (nn.Module): Feed-forward module applied to ``size``-dim features.
        dropout_rate (float): Dropout rate on the sub-layer outputs.
        normalize_before (bool): If True apply each layer norm before its sub-layer,
            otherwise after the residual addition.
        concat_after (bool): If True, combine the attention input and output by
            concatenation and a linear projection instead of dropout.
        stochastic_depth_rate (float): Probability of skipping the whole layer
            during training; surviving outputs are scaled by ``1 / (1 - rate)``.

    The attention residual connection is used only when ``in_size == size``.
    """

    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayerSANM, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            # The projection input is [attention input (in_size), attention output (size)].
            self.concat_linear = nn.Linear(in_size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_size).
            mask (torch.Tensor): Mask tensor for the input, passed through to ``self_attn``.
            cache (torch.Tensor, optional): Only used when the layer is skipped, in
                which case it is concatenated in front of ``x`` along time.
            mask_shfit_chunk (torch.Tensor, optional): Forwarded to ``self_attn``.
            mask_att_chunk_encoder (torch.Tensor, optional): Forwarded to ``self_attn``.

        Returns:
            tuple: ``(x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder)``, with
            ``x`` of shape (#batch, time, size). The tuple has the same layout on
            every path, including the stochastic-depth skip.
        """
        skip_layer = False
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        attn = self.self_attn(
            x,
            mask,
            mask_shfit_chunk=mask_shfit_chunk,
            mask_att_chunk_encoder=mask_att_chunk_encoder,
        )
        if self.concat_after:
            x = stoch_layer_coeff * self.concat_linear(torch.cat((x, attn), dim=-1))
        else:
            x = stoch_layer_coeff * self.dropout(attn)
        if self.in_size == self.size:
            x = residual + x
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder

    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
        """Compute encoded features for one streaming chunk (no dropout, no masking).

        Args:
            x (torch.Tensor): Input tensor (#batch, time, in_size).
            cache: Attention cache forwarded to ``self_attn.forward_chunk``.
            chunk_size: Forwarded to ``self_attn.forward_chunk``.
            look_back (int): Forwarded to ``self_attn.forward_chunk``.

        Returns:
            tuple: Output tensor (#batch, time, size) and the updated cache.
        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if self.in_size == self.size:
            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
            x = residual + attn
        else:
            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)

        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, cache
|
|
|
|
|
@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
class SenseVoiceEncoderSmall(nn.Module):
    """SenseVoice-Small encoder built from SAN-M encoder layers.

    Author: Speech Lab of DAMO Academy, Alibaba Group
    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
    https://arxiv.org/abs/2006.01713

    Layout:
      * ``encoders0``: one layer mapping ``input_size`` to ``output_size``;
      * ``encoders``: ``num_blocks - 1`` layers, followed by ``after_norm``;
      * ``tp_encoders``: ``tp_blocks`` extra layers, followed by ``tp_norm``.

    Only ``input_size``, ``output_size``, ``attention_heads``, ``linear_units``,
    ``num_blocks``, ``tp_blocks``, ``dropout_rate``, ``attention_dropout_rate``,
    ``kernel_size`` and ``sanm_shfit`` affect the model. ``normalize_before`` is
    stored but not forwarded to the layers. All other arguments are accepted for
    configuration compatibility and ignored.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        tp_blocks: int = 0,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size

        self.embed = SinusoidalPositionEncoder()

        self.normalize_before = normalize_before

        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )

        encoder_selfattn_layer = MultiHeadedAttentionSANM
        # The first layer projects from input_size; all later layers are output_size -> output_size.
        encoder_selfattn_layer_args0 = (
            attention_heads,
            input_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            output_size,
            attention_dropout_rate,
            kernel_size,
            sanm_shfit,
        )

        self.encoders0 = nn.ModuleList(
            [
                EncoderLayerSANM(
                    input_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args0),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for _ in range(1)
            ]
        )
        self.encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for _ in range(num_blocks - 1)
            ]
        )

        self.tp_encoders = nn.ModuleList(
            [
                EncoderLayerSANM(
                    output_size,
                    output_size,
                    encoder_selfattn_layer(*encoder_selfattn_layer_args),
                    positionwise_layer(*positionwise_layer_args),
                    dropout_rate,
                )
                for _ in range(tp_blocks)
            ]
        )

        self.after_norm = LayerNorm(output_size)

        self.tp_norm = LayerNorm(output_size)

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
    ):
        """Encode a padded batch of features.

        Args:
            xs_pad: Padded features (batch, time, input_size). Not modified.
            ilens: Valid length of each sequence, (batch,).

        Returns:
            tuple: Encoder output (batch, time, output_size) and int32 output
            lengths, computed from the mask before the ``tp_encoders`` stack.
        """
        # The mask follows the input dtype. Multiplying activations by a mask of a
        # different floating type (e.g. bf16 mask with fp16 activations) would
        # promote them to float32 and break the fp16 conv/linear layers.
        masks = sequence_mask(ilens, dtype=xs_pad.dtype, device=ilens.device)[:, None, :]

        # Scale out-of-place so the caller's tensor is left untouched.
        xs_pad = xs_pad * self.output_size() ** 0.5

        xs_pad = self.embed(xs_pad)

        # Each layer returns a tuple whose first two entries are (features, mask).
        for encoder_layer in self.encoders0:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        for encoder_layer in self.encoders:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.after_norm(xs_pad)

        olens = (masks > 0.5).squeeze(1).sum(1).int()

        for encoder_layer in self.tp_encoders:
            encoder_outs = encoder_layer(xs_pad, masks)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]

        xs_pad = self.tp_norm(xs_pad)
        return xs_pad, olens
|
|
|
|
|
@tables.register("model_classes", "SenseVoiceSmall") |
|
class SenseVoiceSmall(nn.Module): |
|
"""CTC-attention hybrid Encoder-Decoder model""" |
|
|
|
def __init__( |
|
self, |
|
specaug: str = None, |
|
specaug_conf: dict = None, |
|
normalize: str = None, |
|
normalize_conf: dict = None, |
|
encoder: str = None, |
|
encoder_conf: dict = None, |
|
ctc_conf: dict = None, |
|
input_size: int = 80, |
|
vocab_size: int = -1, |
|
ignore_id: int = -1, |
|
blank_id: int = 0, |
|
sos: int = 1, |
|
eos: int = 2, |
|
length_normalized_loss: bool = False, |
|
**kwargs, |
|
): |
|
|
|
super().__init__() |
|
|
|
if specaug is not None: |
|
specaug_class = tables.specaug_classes.get(specaug) |
|
specaug = specaug_class(**specaug_conf) |
|
if normalize is not None: |
|
normalize_class = tables.normalize_classes.get(normalize) |
|
normalize = normalize_class(**normalize_conf) |
|
encoder_class = tables.encoder_classes.get(encoder) |
|
encoder = encoder_class(input_size=input_size, **encoder_conf) |
|
encoder_output_size = encoder.output_size() |
|
|
|
if ctc_conf is None: |
|
ctc_conf = {} |
|
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf) |
|
|
|
self.blank_id = blank_id |
|
self.sos = sos if sos is not None else vocab_size - 1 |
|
self.eos = eos if eos is not None else vocab_size - 1 |
|
self.vocab_size = vocab_size |
|
self.ignore_id = ignore_id |
|
self.specaug = specaug |
|
self.normalize = normalize |
|
self.encoder = encoder |
|
self.error_calculator = None |
|
|
|
self.ctc = ctc |
|
|
|
self.length_normalized_loss = length_normalized_loss |
|
self.encoder_output_size = encoder_output_size |
|
|
|
self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13} |
|
self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13} |
|
self.textnorm_dict = {"withitn": 14, "woitn": 15} |
|
self.textnorm_int_dict = {25016: 14, 25017: 15} |
|
self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size) |
|
self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004} |
|
|
|
self.criterion_att = LabelSmoothingLoss( |
|
size=self.vocab_size, |
|
padding_idx=self.ignore_id, |
|
smoothing=kwargs.get("lsm_weight", 0.0), |
|
normalize_length=self.length_normalized_loss, |
|
) |
|
|
|
@staticmethod |
|
def from_pretrained(model:str=None, **kwargs): |
|
from funasr import AutoModel |
|
model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs) |
|
|
|
return model, kwargs |
|
|
|
def forward( |
|
self, |
|
speech: torch.Tensor, |
|
speech_lengths: torch.Tensor, |
|
text: torch.Tensor, |
|
text_lengths: torch.Tensor, |
|
**kwargs, |
|
): |
|
"""Encoder + Decoder + Calc loss |
|
Args: |
|
speech: (Batch, Length, ...) |
|
speech_lengths: (Batch, ) |
|
text: (Batch, Length) |
|
text_lengths: (Batch,) |
|
""" |
|
|
|
|
|
if len(text_lengths.size()) > 1: |
|
text_lengths = text_lengths[:, 0] |
|
if len(speech_lengths.size()) > 1: |
|
speech_lengths = speech_lengths[:, 0] |
|
|
|
batch_size = speech.shape[0] |
|
|
|
|
|
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text) |
|
|
|
loss_ctc, cer_ctc = None, None |
|
loss_rich, acc_rich = None, None |
|
stats = dict() |
|
|
|
loss_ctc, cer_ctc = self._calc_ctc_loss( |
|
encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4 |
|
) |
|
|
|
loss_rich, acc_rich = self._calc_rich_ce_loss( |
|
encoder_out[:, :4, :], text[:, :4] |
|
) |
|
|
|
loss = loss_ctc + loss_rich |
|
|
|
stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None |
|
stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None |
|
stats["loss"] = torch.clone(loss.detach()) if loss is not None else None |
|
stats["acc_rich"] = acc_rich |
|
|
|
|
|
if self.length_normalized_loss: |
|
batch_size = int((text_lengths + 1).sum()) |
|
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
|
return loss, stats, weight |
|
|
|
def encode( |
|
self, |
|
speech: torch.Tensor, |
|
speech_lengths: torch.Tensor, |
|
text: torch.Tensor, |
|
**kwargs, |
|
): |
|
"""Frontend + Encoder. Note that this method is used by asr_inference.py |
|
Args: |
|
speech: (Batch, Length, ...) |
|
speech_lengths: (Batch, ) |
|
ind: int |
|
""" |
|
|
|
|
|
if self.specaug is not None and self.training: |
|
speech, speech_lengths = self.specaug(speech, speech_lengths) |
|
|
|
|
|
if self.normalize is not None: |
|
speech, speech_lengths = self.normalize(speech, speech_lengths) |
|
|
|
|
|
lids = torch.LongTensor([[self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0 ] for lid in text[:, 0]]).to(speech.device) |
|
language_query = self.embed(lids) |
|
|
|
styles = torch.LongTensor([[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]).to(speech.device) |
|
style_query = self.embed(styles) |
|
speech = torch.cat((style_query, speech), dim=1) |
|
speech_lengths += 1 |
|
|
|
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1) |
|
input_query = torch.cat((language_query, event_emo_query), dim=1) |
|
speech = torch.cat((input_query, speech), dim=1) |
|
speech_lengths += 3 |
|
|
|
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) |
|
|
|
return encoder_out, encoder_out_lens |
|
|
|
def _calc_ctc_loss( |
|
self, |
|
encoder_out: torch.Tensor, |
|
encoder_out_lens: torch.Tensor, |
|
ys_pad: torch.Tensor, |
|
ys_pad_lens: torch.Tensor, |
|
): |
|
|
|
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) |
|
|
|
|
|
cer_ctc = None |
|
if not self.training and self.error_calculator is not None: |
|
ys_hat = self.ctc.argmax(encoder_out).data |
|
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) |
|
return loss_ctc, cer_ctc |
|
|
|
def _calc_rich_ce_loss( |
|
self, |
|
encoder_out: torch.Tensor, |
|
ys_pad: torch.Tensor, |
|
): |
|
decoder_out = self.ctc.ctc_lo(encoder_out) |
|
|
|
loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous()) |
|
acc_rich = th_accuracy( |
|
decoder_out.view(-1, self.vocab_size), |
|
ys_pad.contiguous(), |
|
ignore_label=self.ignore_id, |
|
) |
|
|
|
return loss_rich, acc_rich |
|
|
|
|
|
def inference( |
|
self, |
|
data_in, |
|
data_lengths=None, |
|
key: list = ["wav_file_tmp_name"], |
|
tokenizer=None, |
|
frontend=None, |
|
**kwargs, |
|
): |
|
|
|
|
|
meta_data = {} |
|
if ( |
|
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank" |
|
): |
|
speech, speech_lengths = data_in, data_lengths |
|
if len(speech.shape) < 3: |
|
speech = speech[None, :, :] |
|
if speech_lengths is None: |
|
speech_lengths = speech.shape[1] |
|
else: |
|
|
|
time1 = time.perf_counter() |
|
audio_sample_list = load_audio_text_image_video( |
|
data_in, |
|
fs=frontend.fs, |
|
audio_fs=kwargs.get("fs", 16000), |
|
data_type=kwargs.get("data_type", "sound"), |
|
tokenizer=tokenizer, |
|
) |
|
time2 = time.perf_counter() |
|
meta_data["load_data"] = f"{time2 - time1:0.3f}" |
|
speech, speech_lengths = extract_fbank( |
|
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend |
|
) |
|
time3 = time.perf_counter() |
|
meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
|
meta_data["batch_data_time"] = ( |
|
speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 |
|
) |
|
|
|
speech = speech.to(device=kwargs["device"]) |
|
speech_lengths = speech_lengths.to(device=kwargs["device"]) |
|
|
|
language = kwargs.get("language", "auto") |
|
language_query = self.embed( |
|
torch.LongTensor( |
|
[[self.lid_dict[language] if language in self.lid_dict else 0]] |
|
).to(speech.device) |
|
).repeat(speech.size(0), 1, 1) |
|
|
|
use_itn = kwargs.get("use_itn", False) |
|
output_timestamp = kwargs.get("output_timestamp", False) |
|
|
|
textnorm = kwargs.get("text_norm", None) |
|
if textnorm is None: |
|
textnorm = "withitn" if use_itn else "woitn" |
|
textnorm_query = self.embed( |
|
torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device) |
|
).repeat(speech.size(0), 1, 1) |
|
speech = torch.cat((textnorm_query, speech), dim=1) |
|
speech_lengths += 1 |
|
|
|
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( |
|
speech.size(0), 1, 1 |
|
) |
|
input_query = torch.cat((language_query, event_emo_query), dim=1) |
|
speech = torch.cat((input_query, speech), dim=1) |
|
speech_lengths += 3 |
|
|
|
|
|
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) |
|
if isinstance(encoder_out, tuple): |
|
encoder_out = encoder_out[0] |
|
|
|
|
|
ctc_logits = self.ctc.log_softmax(encoder_out) |
|
if kwargs.get("ban_emo_unk", False): |
|
ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf") |
|
|
|
results = [] |
|
b, n, d = encoder_out.size() |
|
if isinstance(key[0], (list, tuple)): |
|
key = key[0] |
|
if len(key) < b: |
|
key = key * b |
|
for i in range(b): |
|
x = ctc_logits[i, : encoder_out_lens[i].item(), :] |
|
yseq = x.argmax(dim=-1) |
|
yseq = torch.unique_consecutive(yseq, dim=-1) |
|
|
|
ibest_writer = None |
|
if kwargs.get("output_dir") is not None: |
|
if not hasattr(self, "writer"): |
|
self.writer = DatadirWriter(kwargs.get("output_dir")) |
|
ibest_writer = self.writer[f"1best_recog"] |
|
|
|
mask = yseq != self.blank_id |
|
token_int = yseq[mask].tolist() |
|
|
|
|
|
text = tokenizer.decode(token_int) |
|
if ibest_writer is not None: |
|
ibest_writer["text"][key[i]] = text |
|
|
|
if output_timestamp: |
|
from itertools import groupby |
|
timestamp = [] |
|
tokens = tokenizer.text2tokens(text)[4:] |
|
|
|
logits_speech = self.ctc.softmax(encoder_out)[i, 4:encoder_out_lens[i].item(), :] |
|
|
|
pred = logits_speech.argmax(-1).cpu() |
|
logits_speech[pred==self.blank_id, self.blank_id] = 0 |
|
|
|
align = ctc_forced_align( |
|
logits_speech.unsqueeze(0).float(), |
|
torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device), |
|
(encoder_out_lens-4).long(), |
|
torch.tensor(len(token_int)-4).unsqueeze(0).long().to(logits_speech.device), |
|
ignore_id=self.ignore_id, |
|
) |
|
|
|
pred = groupby(align[0, :encoder_out_lens[0]]) |
|
_start = 0 |
|
token_id = 0 |
|
ts_max = encoder_out_lens[i] - 4 |
|
for pred_token, pred_frame in pred: |
|
_end = _start + len(list(pred_frame)) |
|
if pred_token != 0: |
|
ts_left = max((_start*60-30)/1000, 0) |
|
ts_right = min((_end*60-30)/1000, (ts_max*60-30)/1000) |
|
timestamp.append([tokens[token_id], ts_left, ts_right]) |
|
token_id += 1 |
|
_start = _end |
|
|
|
result_i = {"key": key[i], "text": text, "timestamp": timestamp} |
|
results.append(result_i) |
|
else: |
|
result_i = {"key": key[i], "text": text} |
|
results.append(result_i) |
|
return results, meta_data |
|
|
|
|
|
def inference_encode( |
|
self, |
|
data_in, |
|
data_lengths=None, |
|
key: list = ["wav_file_tmp_name"], |
|
**kwargs, |
|
): |
|
|
|
|
|
speech, speech_lengths = data_in, data_lengths |
|
if len(speech.shape) < 3: |
|
speech = speech[None, :, :] |
|
if speech_lengths is None: |
|
speech_lengths = speech.shape[1] |
|
|
|
speech = speech.to(device=kwargs["device"]) |
|
speech_lengths = speech_lengths.to(device=kwargs["device"]) |
|
|
|
language = kwargs.get("language", "auto") |
|
language_query = self.embed( |
|
torch.LongTensor( |
|
[[self.lid_dict[language] if language in self.lid_dict else 0]] |
|
).to(speech.device) |
|
).repeat(speech.size(0), 1, 1) |
|
|
|
use_itn = kwargs.get("use_itn", False) |
|
output_timestamp = kwargs.get("output_timestamp", False) |
|
|
|
textnorm = kwargs.get("text_norm", None) |
|
if textnorm is None: |
|
textnorm = "withitn" if use_itn else "woitn" |
|
textnorm_query = self.embed( |
|
torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device) |
|
).repeat(speech.size(0), 1, 1) |
|
speech = torch.cat((textnorm_query, speech), dim=1) |
|
speech_lengths += 1 |
|
|
|
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( |
|
speech.size(0), 1, 1 |
|
) |
|
input_query = torch.cat((language_query, event_emo_query), dim=1) |
|
speech = torch.cat((input_query, speech), dim=1) |
|
speech_lengths += 3 |
|
|
|
|
|
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) |
|
if isinstance(encoder_out, tuple): |
|
encoder_out = encoder_out[0] |
|
|
|
return encoder_out, encoder_out_lens |
|
|
|
def export_rebuild_model(model, **kwargs):
    """Patch ``model`` in place with export-time attributes and bound methods.

    Sets ``model.device`` from ``kwargs["device"]`` (``None`` when absent),
    builds ``model.make_pad_mask`` with ``sequence_mask`` using
    ``kwargs["max_seq_len"]`` (required), and binds each ``export_*`` helper
    function to ``model`` as a method.

    Returns:
        The same ``model`` object, after modification.
    """
    model.device = kwargs.get("device")

    max_len = kwargs["max_seq_len"]
    model.make_pad_mask = sequence_mask(max_len, flip=False)

    # (attribute name, function to bind as a method), applied in this order.
    method_table = (
        ("forward", export_forward),
        ("export_dummy_inputs", export_dummy_inputs),
        ("export_input_names", export_input_names),
        ("export_output_names", export_output_names),
        ("export_dynamic_axes", export_dynamic_axes),
        ("export_name", export_name),
    )
    for attr_name, func in method_table:
        setattr(model, attr_name, types.MethodType(func, model))

    return model
|
|
|
def export(self, **kwargs):
    """Prepare this model for export and return it.

    ``max_seq_len`` defaults to 512 when the caller does not pass it. All
    keyword arguments are forwarded to ``export_rebuild_model``, which patches
    and returns ``self``.
    """
    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
|
|
|
|
|
class AudioEncoder(nn.Module):
    """Expose a FunASR model directory as a batch audio encoder module.

    The wrapped model is built by :meth:`build_model` in ``__init__``, and
    :meth:`forward` delegates to that model's ``inference_encode`` method.
    """

    def __init__(
        self,
        config,
    ):
        """Build the wrapped model.

        Args:
            config: object supporting ``in`` membership tests. When it
                contains ``_name_or_path``, that attribute is used as the
                model directory; otherwise the directory containing this
                source file is used.
        """
        super().__init__()
        import os

        if "_name_or_path" in config:
            model_dir = config._name_or_path
        else:
            model_dir = os.path.dirname(os.path.abspath(__file__))

        self.model, self.kwargs = self.build_model(model=model_dir, trust_remote_code=False,)

    def forward(
        self,
        audios,
    ):
        """Pad a batch of feature sequences and run the wrapped encoder.

        Args:
            audios: sequence of per-utterance feature tensors whose first
                dimension is time (the layout ``pad_sequence`` expects).

        Returns:
            ``(encoder_out, encoder_out_lens)`` as returned by
            ``self.model.inference_encode``.
        """
        from torch.nn.utils.rnn import pad_sequence

        feats_pad = pad_sequence(audios, batch_first=True, padding_value=0.0)
        feats_lens = torch.as_tensor([len(x) for x in audios])

        # NOTE(review): features are always cast to bfloat16, which presumably
        # assumes the wrapped model runs in bf16 — confirm.
        feats_pad = feats_pad.to(torch.bfloat16)

        encoder_out, encoder_out_lens = self.model.inference_encode(
            feats_pad,
            data_lengths=feats_lens,
            language="auto",
            use_itn=False,
            ban_emo_unk=False,
            **self.kwargs,
        )
        return encoder_out, encoder_out_lens

    @staticmethod
    def build_model(**kwargs):
        """Build a registered FunASR model and return ``(model, kwargs)``.

        Steps: optionally download the model (when ``model_conf`` is
        missing), seed RNGs, pick the device, build tokenizer(s) and
        frontend, instantiate the registered model class, optionally load
        ``init_param`` weights, and optionally cast to fp16/bf16.

        Returns:
            ``(model, kwargs)``. The returned ``kwargs`` is enriched with
            ``device``, ``tokenizer``, ``vocab_size``, ``token_list`` (when a
            tokenizer is configured), ``frontend`` and ``input_size``.

        Raises:
            AssertionError: if ``model`` is missing from ``kwargs`` or the
                model class is not registered.
        """
        from omegaconf import ListConfig
        import logging
        import os

        from funasr.download.download_model_from_hub import download_model
        from funasr.train_utils.set_all_random_seed import set_all_random_seed
        from funasr.register import tables
        from funasr.train_utils.load_pretrained_model import load_pretrained_model
        from funasr.utils.misc import deep_update

        assert "model" in kwargs
        if "model_conf" not in kwargs:
            logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
            kwargs = download_model(**kwargs)

        set_all_random_seed(kwargs.get("seed", 0))

        # Device selection: fall back to CPU (with batch size 1) when CUDA is
        # unavailable or ngpu == 0.
        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device

        torch.set_num_threads(kwargs.get("ncpu", 4))

        # Tokenizer(s): a comma-separated string or a list of registered names.
        tokenizer = kwargs.get("tokenizer", None)
        kwargs["tokenizer"] = tokenizer
        kwargs["vocab_size"] = -1

        if tokenizer is not None:
            tokenizers = tokenizer.split(",") if isinstance(tokenizer, str) else tokenizer
            tokenizers_conf = kwargs.get("tokenizer_conf", {})
            tokenizers_build = []
            vocab_sizes = []
            token_lists = []

            token_list_files = kwargs.get("token_lists", [])
            seg_dicts = kwargs.get("seg_dicts", [])

            if not isinstance(tokenizers_conf, (list, tuple, ListConfig)):
                # NOTE(review): list multiplication makes every entry the SAME
                # conf object, so the token_list/seg_dict writes below mutate
                # that shared conf (and kwargs["tokenizer_conf"]) — confirm
                # this is intended.
                tokenizers_conf = [tokenizers_conf] * len(tokenizers)

            for i, tokenizer in enumerate(tokenizers):
                tokenizer_class = tables.tokenizer_classes.get(tokenizer)
                tokenizer_conf = tokenizers_conf[i]

                # Per-tokenizer files are applied only when several are given.
                if len(token_list_files) > 1:
                    tokenizer_conf["token_list"] = token_list_files[i]
                if len(seg_dicts) > 1:
                    tokenizer_conf["seg_dict"] = seg_dicts[i]

                tokenizer = tokenizer_class(**tokenizer_conf)
                tokenizers_build.append(tokenizer)

                # Prefer get_vocab() over token_list; fall back to
                # get_vocab_size() when neither yields a list.
                token_list = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
                token_list = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else token_list
                vocab_size = -1
                if token_list is not None:
                    vocab_size = len(token_list)
                if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
                    vocab_size = tokenizer.get_vocab_size()
                token_lists.append(token_list)
                vocab_sizes.append(vocab_size)

            # A single tokenizer is unwrapped from its list.
            # NOTE(review): with zero tokenizers this indexes an empty list and
            # raises IndexError.
            if len(tokenizers_build) <= 1:
                tokenizers_build = tokenizers_build[0]
                token_lists = token_lists[0]
                vocab_sizes = vocab_sizes[0]

            kwargs["tokenizer"] = tokenizers_build
            kwargs["vocab_size"] = vocab_sizes
            kwargs["token_list"] = token_lists

        # Frontend (feature extractor), if configured.
        frontend = kwargs.get("frontend", None)
        kwargs["input_size"] = None
        if frontend is not None:
            frontend_class = tables.frontend_classes.get(frontend)
            frontend = frontend_class(**kwargs.get("frontend_conf", {}))
            kwargs["input_size"] = (
                frontend.output_size() if hasattr(frontend, "output_size") else None
            )
        kwargs["frontend"] = frontend

        # Model: model_conf first, then all kwargs layered on top.
        model_class = tables.model_classes.get(kwargs["model"])
        assert model_class is not None, f'{kwargs["model"]} is not registered'
        model_conf = {}
        deep_update(model_conf, kwargs.get("model_conf", {}))
        deep_update(model_conf, kwargs)
        model = model_class(**model_conf)

        # Optional pretrained weights.
        init_param = kwargs.get("init_param", None)
        if init_param is not None:
            if os.path.exists(init_param):
                logging.info(f"Loading pretrained params from {init_param}")
                load_pretrained_model(
                    model=model,
                    path=init_param,
                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                    oss_bucket=kwargs.get("oss_bucket", None),
                    scope_map=kwargs.get("scope_map", []),
                    excludes=kwargs.get("excludes", None),
                )
            else:
                # Best-effort: report and continue with randomly initialized weights.
                logging.error(f"error, init_param does not exist!: {init_param}")

        # Optional reduced precision; fp16 takes priority over bf16.
        if kwargs.get("fp16", False):
            model.to(torch.float16)
        elif kwargs.get("bf16", False):
            model.to(torch.bfloat16)

        if not kwargs.get("disable_log", True):
            tables.print()

        return model, kwargs
|
|
|
|