|
|
|
|
|
import itertools as it |
|
from typing import List |
|
|
|
import pandas as pd |
|
import torch |
|
from transformers import LogitsProcessor, PreTrainedTokenizer |
|
|
|
|
|
class CTCPrefixScore(object):
    """Compute CTC label sequence scores.

    Based on Algorithm 2 in Watanabe et al., "Hybrid CTC/Attention
    Architecture for End-to-End Speech Recognition", but extended to
    efficiently compute the label probabilities for multiple hypotheses
    simultaneously. See also Seki et al., "Vectorized Beam Search for
    CTC-Attention-Based Speech Recognition", INTERSPEECH 2019, pp. 3825-3829.

    Layouts used throughout (fixed by the indexing in this class):

    * ``x``: ``(batch, frames, vocab)``. Values are combined with
      cumsum/logaddexp, so they are expected to be log-probabilities.
    * CTC state ``r``: ``(n, frames, 2)`` for a prefix, or
      ``(n, frames, 2, labels)`` for all scored extensions. Index 0 of the
      size-2 axis is log r^n (path ends in a non-blank); index 1 is log r^b
      (path ends in blank).
    """

    def __init__(self, x, blank, eos):
        self.logzero = -1e10
        self.blank = blank
        self.eos = eos
        self.input_length = x.shape[1]
        self.batch_size = x.shape[0]
        self.x = x
        self.device = x.device
        self.max_num_labels = x.shape[2]
        # Work buffers. ``_resize_tensors`` (re)allocates them to
        # (active hypotheses, frames, [2,] scored labels) on every call.
        # They start with an empty label axis instead of the full vocabulary:
        # vocab-sized buffers cost batch * frames * 3 * vocab floats up front
        # and are never used at that size.
        self.r = torch.full((self.batch_size, self.input_length, 2, 0), self.logzero,
                            device=self.device)
        self.xs = torch.full((self.batch_size, self.input_length, 0), self.logzero,
                             device=self.device)

    def initial_state(self):
        """Obtain an initial CTC state (the empty prefix).

        :return: ``(r, s)``. ``r`` has shape ``(batch, frames, 2)`` with
            log r^n = logzero and log r^b = cumulative blank log-probability.
            ``s`` is a ``(batch, 1)`` tensor of zeros (initial prefix score).
        """
        r = torch.full((self.batch_size, self.input_length, 2), self.logzero, device=self.device)
        r[..., 1] = torch.cumsum(self.x[..., self.blank], dim=1)
        s = torch.zeros((self.batch_size, 1), device=self.device)
        return r, s

    def _resize_tensors(self, number_of_current_samples, num_labels):
        """Shape ``self.r`` / ``self.xs`` for the current call and reset them to logzero."""
        r_shape = (number_of_current_samples, self.input_length, 2, num_labels)
        if tuple(self.r.shape) != r_shape:
            # Reallocate rather than slice. Slicing cannot grow the buffers when
            # more hypotheses are active than on the previous call, and a view
            # would keep the previous (larger) storage alive.
            self.r = torch.full(r_shape, self.logzero, device=self.device)
            self.xs = torch.full((number_of_current_samples, self.input_length, num_labels),
                                 self.logzero, device=self.device)
        else:
            self.r.fill_(self.logzero)
            self.xs.fill_(self.logzero)

    def _initialize_r(self, decoded_len):
        """Seed r_0^n(h) = x_0(c) for hypotheses whose prefix is still empty."""
        mask = (decoded_len == 0)
        self.r[mask, 0, 0, :] = self.xs[mask, 0]

    def _compute_log_phi(self, r_sum, cs, last, decoded_len, r_prev):
        """Forward probability phi_t(g) of the prefix g for each candidate label.

        phi_t is r_t^n(g) + r_t^b(g) (``r_sum``). The exception is a candidate
        equal to the last label of a non-empty prefix: only paths ending in
        blank (r_t^b, ``r_prev[..., 1]``) may be extended by a repeated label.

        :return: tensor of shape ``(n, frames, num_labels)``
        """
        log_phi = r_sum[..., None].expand(-1, -1, cs.shape[1])
        repeats_last = (decoded_len > 0).unsqueeze(1) & (cs == last.unsqueeze(1))
        return torch.where(repeats_last[:, None, :], r_prev[..., 1:2], log_phi)

    def _compute_log_psi(self, decoded_len, log_phi, x_current):
        """Compute forward probabilities and log prefix probabilities.

        Forward probabilities log(r_t^n(h)) and log(r_t^b(h)) are written into
        ``self.r``. The log prefix probabilities log(psi) are returned for all
        labels in the batch.

        :param decoded_len: tensor of shape (n,) with the length of each decoded prefix
        :param log_phi: tensor of shape (n, input_length, num_labels) with phi_t(g)
        :param x_current: tensor of shape (n, input_length, vocab) with the input frames
        :return log_psi: tensor of shape (n, num_labels) with the log prefix probabilities
        """
        batch, frames, _ = log_phi.shape
        # The recursion for a prefix of length d starts at frame max(d, 1).
        start = torch.clamp(decoded_len, min=1)

        # psi starts from r_{start-1}^n(h). Read it before the recursion below
        # overwrites r: the value is logzero except for empty prefixes seeded
        # by _initialize_r.
        log_psi = self.r[torch.arange(batch, device=start.device), start - 1, 0, :]

        # Frame t (1 <= t < frames) adds phi_{t-1}(g) * x_t(c) once t >= start.
        frame_ids = torch.arange(1, frames, device=decoded_len.device)
        in_range = frame_ids.expand(batch, frames - 1) >= decoded_len.unsqueeze(1)
        emitted = torch.where(in_range.unsqueeze(-1), log_phi[:, :-1] + self.xs[:, 1:], self.logzero)
        log_psi = torch.logaddexp(log_psi, torch.logsumexp(emitted, dim=1))

        if batch == 0:
            # Nothing to decode; start.min() would fail on an empty tensor.
            return log_psi

        # All rows run from the smallest start frame, so rows that have not yet
        # reached their own start frame are reset to logzero after each step.
        for t in range(int(start.min()), self.input_length):
            should_decode = decoded_len <= t
            self.r[:, t, 0] = torch.logaddexp(self.r[:, t - 1, 0], log_phi[:, t - 1]) + self.xs[:, t]
            self.r[:, t, 1] = (
                torch.logaddexp(self.r[:, t - 1, 0], self.r[:, t - 1, 1])
                + x_current[:, t, self.blank][:, None]
            )
            if not should_decode.all():
                self.r[:, t] = torch.where(should_decode[:, None, None], self.r[:, t], self.logzero)

        return log_psi

    def _update_log_psi_with_eos(self, log_psi, cs, r_sum):
        """Score EOS candidates with the full-sequence probability and drop blank.

        EOS columns receive r^n + r^b of the prefix at the last frame of ``x``.
        NOTE(review): the last frame is used for every row, so padded encoder
        frames (if any) are included. Confirm this is intended.
        When ``eos != blank``, blank candidates are set to logzero.
        ``log_psi`` is modified in place and returned.
        """
        eos_mask = (cs == self.eos)
        log_psi[eos_mask] = r_sum[:, -1].unsqueeze(1).expand_as(log_psi)[eos_mask]

        if self.eos != self.blank:
            blank_mask = (cs == self.blank)
            log_psi[blank_mask] = self.logzero
        return log_psi

    def __call__(self, y, cs, decoded_len, samples_to_be_decoded, r_prev):
        """Compute CTC prefix scores for next labels.

        :param y: prefix label sequences, shape (n, prefix_len); only ``y[:, -1]`` is read
        :param cs: candidate next labels, shape (n, num_labels)
        :param decoded_len: number of labels in each prefix, shape (n,)
        :param samples_to_be_decoded: index/mask over the batch axis of ``x``
            selecting the n rows being scored
        :param r_prev: previous CTC state of each prefix, shape (n, frames, 2)
        :return: ``(log_psi, r)``. ``log_psi`` has shape (n, num_labels).
            ``r`` is the internal buffer of shape (n, frames, 2, num_labels);
            it is overwritten on the next call, so copy it if needed.
        """
        num_labels = cs.shape[1]
        number_of_current_samples = cs.shape[0]
        self._resize_tensors(number_of_current_samples, num_labels)

        # Per-frame scores of just the candidate labels: (n, frames, num_labels).
        x_current = self.x[samples_to_be_decoded]
        self.xs = torch.gather(x_current, 2, cs.unsqueeze(1).expand(-1, self.input_length, -1))

        self._initialize_r(decoded_len)

        r_sum = torch.logaddexp(r_prev[:, :, 0], r_prev[:, :, 1])
        last = y[:, -1]

        log_phi = self._compute_log_phi(r_sum, cs, last, decoded_len, r_prev)
        log_psi = self._compute_log_psi(decoded_len, log_phi, x_current)
        log_psi = self._update_log_psi_with_eos(log_psi, cs, r_sum)

        return log_psi, self.r
|
|
|
|
|
class CTCRescorerLogitsProcessor(LogitsProcessor):
    """Interpolate decoder scores with CTC prefix scores during generation.

    For every hypothesis, the ``ctc_tokens_to_score`` best non-timestamp
    candidates of ``scores`` (always including EOS) are scored with
    ``CTCPrefixScore``. The result is combined as
    ``(1 - ctc_weight) * scores + ctc_weight * ctc_scores``.

    Token ids ``>= <|0.00|>`` are treated as timestamps. They are not CTC
    labels, they receive the best CTC score of their row, and they do not
    advance the CTC state in ``update_state``.
    """

    def __init__(
        self,
        encoder_logits: torch.FloatTensor,
        encoder_output_lens: torch.Tensor,
        blank_token_id: int,
        pad_token_id: int,
        eos_token_id: int,
        bos_token_id: int,
        tokenizer: PreTrainedTokenizer,
        ctc_margin: int,
        ctc_weight: float,
        num_beams: int,
        debug: bool = False,
        ctc_tokens_to_score: int = 500
    ):
        """
        :param encoder_logits: CTC logits ``(batch, frames, vocab)``. The batch
            axis must match the rows of ``scores`` passed to ``__call__``. The
            vocab axis is one wider than ``scores`` (the temporary buffers use
            ``vocab - 1`` columns); presumably the extra id is the blank. Verify.
        :param encoder_output_lens: NOTE(review): accepted but currently unused.
        :param ctc_margin: NOTE(review): accepted but currently unused.
        :param debug: print per-step diagnostics via ``analyze_predictions``.
        :param ctc_tokens_to_score: number of top decoder candidates rescored by CTC.
        """
        super().__init__()
        logits = torch.nn.functional.log_softmax(encoder_logits, dim=-1)

        # ``tokenizer.upper_cased_tokens`` maps key id -> value id; column
        # ``value`` receives the log-probs of column ``key``. Presumably this
        # ties case variants of a token. Verify against the tokenizer.
        same_logits = torch.tensor(list(tokenizer.upper_cased_tokens.items()), dtype=torch.long)
        if same_logits.numel() > 0:
            logits[..., same_logits[:, 1]] = logits[..., same_logits[:, 0]]

        self.logits = logits

        self.ctc_prefix_scorer = CTCPrefixScore(
            self.logits,
            blank_token_id,
            eos_token_id,
        )
        self.batch_size = logits.shape[0]
        self.input_length = logits.shape[1]
        self.num_tokens = logits.shape[2]
        self.device = logits.device
        self.ctc_weight = ctc_weight
        self.num_beams = num_beams
        self.ctc_state_prev, self.ctc_score_prev = self.ctc_prefix_scorer.initial_state()
        self.eos_token_id = eos_token_id
        self.bos_token_id = bos_token_id
        self.tokenizer = tokenizer
        self.pad_token_id = pad_token_id
        self.blank_token_id = blank_token_id
        self.debug = debug
        self.first_timestamp_token_id = tokenizer.get_vocab()["<|0.00|>"]
        # Per-step CTC scores / states for every (hypothesis, token); filled in
        # __call__ and consumed by update_state.
        self.tmp_ctc_scores = torch.empty((self.batch_size, self.num_tokens - 1), device=self.device)
        self.tmp_ctc_states = torch.empty((self.batch_size, self.num_tokens - 1, self.input_length, 2),
                                          device=self.device)
        self.ctc_tokens_to_score = ctc_tokens_to_score

    def analyze_predictions(self,
                            scores, ctc_scores, next_token_scores, input_ids, k=10):
        """Print the top-k decoder, CTC and combined candidates of every hypothesis (debug aid)."""
        print("\n" + "#" * 100)

        batch_size = input_ids.shape[0]

        # Work on a copy so the caller's CTC scores are left untouched.
        ctc_scores = ctc_scores.clone()
        best_att_ids = scores.topk(k=k, dim=1)
        ctc_scores[:, self.first_timestamp_token_id:] = self.ctc_prefix_scorer.logzero
        best_ctc_ids = ctc_scores.topk(k=k, dim=1)
        best_ids = next_token_scores.topk(k=k, dim=1)

        decoded_prefixes = self.tokenizer.batch_decode(
            input_ids, decode_with_timestamps=True, skip_special_tokens=False
        )

        def prepare_and_decode(best_ids_tensor):
            # Decode each top-k id paired with the '#' token, as one 2-token sequence per id.
            new_tensor = torch.zeros((batch_size, k * 2), dtype=torch.long)
            new_tensor[:, 0::2] = best_ids_tensor.indices
            new_tensor[:, 1::2] = self.tokenizer.vocab['#']

            flat_tensor = new_tensor.view(-1, 2)
            decoded = self.tokenizer.batch_decode(
                flat_tensor, decode_with_timestamps=True, skip_special_tokens=False
            )
            return [decoded[i * k:(i + 1) * k] for i in range(batch_size)]

        def print_with_pandas(tokens, token_scores, title):
            df = pd.DataFrame([tokens, [f"{s.item():.2f}" for s in token_scores]])
            df.index = [f"{title}", "Score"]
            print(f"\n{title}:")
            print(df.to_string(index=True, header=False))

        decoded_att = prepare_and_decode(best_att_ids)
        decoded_ctc = prepare_and_decode(best_ctc_ids)
        decoded_next = prepare_and_decode(best_ids)

        for idx in range(batch_size):
            print("-" * 80)
            print(f"HYPOTHESIS {idx}")
            print("\nPREFIX:")
            print(decoded_prefixes[idx])

            print_with_pandas(decoded_att[idx], best_att_ids.values[idx], "ATT_TOKENS")
            print_with_pandas(decoded_ctc[idx], best_ctc_ids.values[idx], "CTC_TOKENS")
            print_with_pandas(decoded_next[idx], best_ids.values[idx], "NEXT_TOKENS")

            print(f"\nCTC_EOS: {ctc_scores[idx, self.tokenizer.eos_token_id].item():.2f}")
            print()

        print("#" * 100)

    def update_state(self, best_ids, beam_idx):
        """Advance the stored CTC state after the search picked ``best_ids`` from rows ``beam_idx``.

        Only non-timestamp tokens (< ``<|0.00|>``) take the state/score computed
        for them in the last ``__call__``. Otherwise the source beam's previous
        state is carried over.
        """
        mask = best_ids < self.first_timestamp_token_id
        self.ctc_state_prev = torch.where(mask.unsqueeze(-1).unsqueeze(-1),
                                          self.tmp_ctc_states[beam_idx, best_ids],
                                          self.ctc_state_prev[beam_idx])
        self.ctc_score_prev = torch.where(mask.unsqueeze(-1),
                                          self.tmp_ctc_scores[beam_idx, best_ids].unsqueeze(-1),
                                          self.ctc_score_prev[beam_idx])

    def _prepare_ctc_prefixes(self, input_ids_orig):
        """Return a copy of the decoder prefixes laid out as CTC label sequences.

        Leading tokens before BOS are dropped, and all forced prefix tokens but
        the last are removed. That last one is replaced by the blank, which
        plays the role of <sos> for the CTC scorer. ``input_ids_orig`` is not modified.
        """
        input_ids = input_ids_orig.clone()

        if (input_ids[:, 0] != self.bos_token_id).any():
            input_ids = torch.stack(
                [row[(row == self.bos_token_id).nonzero(as_tuple=True)[0].item():] for row in input_ids])

        input_prefix_len = len(self.tokenizer.prefix_tokens)
        if input_prefix_len > 1:
            input_ids = input_ids[:, input_prefix_len - 1:]

        input_ids[:, 0] = self.blank_token_id
        return input_ids

    def __call__(self, input_ids_orig: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` interpolated with CTC prefix scores (same shape as ``scores``)."""
        input_ids = self._prepare_ctc_prefixes(input_ids_orig)

        # Text tokens are ids below the first timestamp. Blank is handled
        # separately because it may lie beyond the timestamp range.
        is_text = input_ids < self.first_timestamp_token_id
        is_blank = input_ids == self.blank_token_id
        decoded_len = torch.logical_and(is_text, ~is_blank).sum(dim=1)

        # If a hypothesis currently ends in a timestamp, expose its last text
        # token (or the leading blank) as ``y[:, -1]`` to the CTC scorer. Use
        # the position of the last text token, since timestamps may be
        # interleaved with text.
        ends_with_timestamp = torch.logical_and(~is_text[:, -1], ~is_blank[:, -1])
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        last_text_pos = torch.where(is_text | is_blank, positions, 0).max(dim=1).values
        last_text_token = input_ids.gather(1, last_text_pos.unsqueeze(1)).squeeze(1)
        input_ids[ends_with_timestamp, -1] = last_text_token[ends_with_timestamp]

        # Hypotheses that already ended with EOS are not rescored.
        to_be_decoded = input_ids[:, -1] != self.eos_token_id
        self.tmp_ctc_scores[:] = self.ctc_prefix_scorer.logzero

        input_ids_local = input_ids[to_be_decoded]
        num_candidates = min(self.ctc_tokens_to_score, self.first_timestamp_token_id)
        ids_to_score = torch.topk(scores[:, :self.first_timestamp_token_id], k=num_candidates).indices

        # EOS is always scored so CTC can judge whether the hypothesis may end.
        is_eos_present = (ids_to_score == self.eos_token_id).any(dim=1)
        ids_to_score[~is_eos_present, num_candidates - 1] = self.eos_token_id

        decoded_len_local = decoded_len[to_be_decoded]

        ctc_scores_local, ctc_states_local = self.ctc_prefix_scorer(input_ids_local, ids_to_score[to_be_decoded],
                                                                    decoded_len_local, to_be_decoded,
                                                                    self.ctc_state_prev[to_be_decoded])

        # Scatter the scored candidates back into full-vocabulary buffers.
        self.tmp_ctc_scores[to_be_decoded] = (self.tmp_ctc_scores[to_be_decoded]
                                              .scatter(1, ids_to_score[to_be_decoded], ctc_scores_local))
        self.tmp_ctc_states[to_be_decoded] = (self.tmp_ctc_states[to_be_decoded].permute(0, 2, 3, 1)
                                              .scatter(3, ids_to_score[to_be_decoded].unsqueeze(1).unsqueeze(1)
                                                       .repeat(1, *ctc_states_local.shape[1:3], 1), ctc_states_local)
                                              .permute(0, 3, 1, 2))

        # Timestamps are not CTC labels: give them the best CTC score of their row.
        self.tmp_ctc_scores[:, self.first_timestamp_token_id:] = self.tmp_ctc_scores.max(dim=1).values[:, None]
        ctc_scores = self.tmp_ctc_scores - self.ctc_score_prev

        next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores

        if self.debug:
            self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids_orig)

        return next_token_scores
|
|
|
|
|
class LogSoftmaxProcessor(LogitsProcessor):
    """Logits processor that normalizes raw scores into log-probabilities over the last axis."""

    def __init__(self):
        super(LogSoftmaxProcessor, self).__init__()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # The normalization depends only on ``scores``; ``input_ids`` is ignored.
        return scores.log_softmax(dim=-1)
|
|
|
|
|
class GreedyCTCDecoder(torch.nn.Module):
    """Best-path (greedy) CTC decoder that returns padded token-id tensors."""

    def __init__(self, tokenizer, blank=0):
        super().__init__()
        self.blank = blank
        self.tokenizer = tokenizer

    def forward(self, emission: torch.Tensor) -> torch.Tensor:
        """Given a batch of emissions over labels, get the best path of each sequence.

        Args:
            emission (Tensor): Logit tensor of shape ``[batch, num_frames, num_label]``.

        Returns:
            Tensor: ``[batch, max_len]`` token ids. Each row is the frame-wise
            argmax with consecutive repeats merged and blanks removed. Ids
            ``>= len(tokenizer)`` are replaced by ``tokenizer.unk_token_id``, and
            shorter rows are right-padded with ``tokenizer.pad_token_id``.
        """
        vocab_size = len(self.tokenizer)
        sequences = []
        for path in torch.argmax(emission, dim=-1):
            collapsed = torch.unique_consecutive(path, dim=-1)
            tokens = collapsed[collapsed != self.blank]
            # Map out-of-vocabulary ids before padding, so the pad id itself is never rewritten.
            tokens[tokens >= vocab_size] = self.tokenizer.unk_token_id
            sequences.append(tokens)
        return torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True,
                                               padding_value=self.tokenizer.pad_token_id)
|
|
|
|
|
def ctc_greedy_decode(logits: torch.Tensor, blank, pad_token_id) -> torch.Tensor:
    """Greedy CTC decoding into a fixed-width tensor of token ids.

    :param logits: scores of shape ``(batch, frames, vocab)``
    :param blank: CTC blank id; it is removed from the output
    :param pad_token_id: value written after each row's decoded labels
    :return: ``(batch, frames)`` tensor (the argmax tensor, rewritten in place).
        Row i holds its collapsed labels (repeats merged, blanks removed),
        followed by ``pad_token_id``.
    """
    idxs = torch.argmax(logits, dim=-1)
    for row in idxs:
        # unique_consecutive returns a new tensor, so writing into ``row`` below is safe.
        collapsed = torch.unique_consecutive(row)
        labels = collapsed[collapsed != blank]
        num_labels = labels.numel()
        row[:num_labels] = labels
        row[num_labels:] = pad_token_id
    return idxs
|
|