# pylint: skip-file # Copied from: https://github.com/espnet/espnet/blob/master/espnet/nets/ctc_prefix_score.py import itertools as it from typing import List import pandas as pd import torch from transformers import LogitsProcessor, PreTrainedTokenizer class CTCPrefixScore(object): """Compute CTC label sequence scores which is based on Algorithm 2 in WATANABE et al. "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION," but extended to efficiently compute the label probabilities for multiple hypotheses simultaneously See also Seki et al. "Vectorized Beam Search for CTC-Attention-Based Speech Recognition," In INTERSPEECH (pp. 3825-3829), 2019. """ def __init__(self, x, blank, eos): self.logzero = -1e10 self.blank = blank self.eos = eos self.input_length = x.shape[1] self.batch_size = x.shape[0] self.x = x self.device = x.device # Preallocate `r` and `xs` tensors # `num_labels` will be set dynamically in __call__ but preallocated with maximum capacity self.max_num_labels = x.shape[2] # Set to a max value that can be dynamically resized self.r = torch.full((self.batch_size, self.input_length, 2, self.max_num_labels), self.logzero, device=self.device) self.xs = torch.full((self.batch_size, self.input_length, self.max_num_labels), self.logzero, device=self.device) def initial_state(self): """Obtain an initial CTC state.""" # Create initial CTC state tensor and use in-place operations to fill r = torch.full((self.batch_size, self.input_length, 2), self.logzero, device=self.device) r[..., 1] = torch.cumsum(self.x[..., self.blank], dim=1) s = torch.zeros((self.batch_size, 1), device=self.device) return r, s def _resize_tensors(self, number_of_current_samples, num_labels): if self.r.shape[0] != number_of_current_samples: self.r = self.r[:number_of_current_samples, ...] self.xs = self.xs[:number_of_current_samples, ...] if self.r.shape[3] != num_labels: self.r = self.r[:, :, :, :num_labels].fill_(self.logzero) self.xs = self.xs[:, :, :num_labels].fill_(self.logzero) else: self.r.fill_(self.logzero) self.xs.fill_(self.logzero) def _initialize_r(self, decoded_len): mask = (decoded_len == 0) self.r[mask, 0, 0, :] = self.xs[mask, 0] def _compute_log_phi(self, r_sum, cs, last, decoded_len, r_prev): # Expand r_sum for num_labels and initialize log_phi log_phi = r_sum[..., None].expand(-1, -1, cs.shape[1]) # Create mask for cases where `decoded_len > 0` and to identify where `c == last[i]` for all `i` non_zero_mask = (decoded_len > 0) label_match_mask = (cs == last.unsqueeze(1)) # Update log_phi where both `decoded_len > 0` and `c == last[i]` log_phi = torch.where((non_zero_mask.unsqueeze(1) & label_match_mask)[:, None, :], r_prev[..., 1:2], log_phi) return log_phi def _compute_log_psi(self, decoded_len, log_phi, x_current): """This function computes forward probabilities log(r_t^n(h)), log(r_t^b(h)), and log prefix probabilities log(psi) for all labels in the batch. :param decoded_len: tensor of shape (batch_size,) containing the length of the decoded sequence :param log_phi: tensor of shape (batch_size, input_length, num_labels) containing the forward probabilities :param x_current: tensor of shape (batch_size, input_length, num_labels) containing the input frame :return log_psi: tensor of shape (batch_size,num_labels) containing the log prefix probabilities """ B, T, V = log_phi.shape start = torch.clamp(decoded_len, min=1) # Ensure start is at least 1 to avoid out-of-bounds # Initialize log_psi with the start position of r[:, start - 1, 0, :] log_psi = self.r[torch.arange(B), start - 1, 0, :] # Mask for handling sequence lengths based on decoded_len mask_t = torch.arange(1, T, device=decoded_len.device).expand(B, T - 1) >= decoded_len.unsqueeze(1) # Accumulate log_psi only up to the last valid time step for each sequence log_psi = torch.logaddexp(log_psi, torch.logsumexp( torch.where(mask_t.unsqueeze(-1), log_phi[:, :-1] + self.xs[:, 1:], self.logzero), dim=1)) start = torch.clamp(decoded_len, 1) # TODO: Vectorize this loop by compute suffix xs and multiplying with log_phi # xs = self.xs[:,1:,:].clone() # xs_cum = torch.cumsum(xs, dim=1) # xs_cum_expanded = xs_cum.unsqueeze(1).repeat(1, T-1, 1, 1) # xs_u = (xs_cum_expanded - torch.nn.functional.pad(xs_cum[:,:-1,:], (0,0,1,0), value=0).unsqueeze(2).repeat(1, 1,T-1,1)).permute(0,2,1,3) # # phis_new = log_phi[:,:-1].clone() # phis_new[:, 0] = torch.logaddexp(phis_new[:, 0], self.r[:, 0, 0, :]) # phis_new = phis_new.unsqueeze(1).repeat(1, T-1, 1, 1) # causal_mask = torch.ones((T-1,T-1), dtype=torch.bool, device=self.device).tril().unsqueeze(0).unsqueeze(-1).repeat(B,1,1,1) # mask = causal_mask & mask_t.unsqueeze(2).unsqueeze(-1) # r_zero = torch.logsumexp(torch.where(mask, xs_u + phis_new, self.logzero), dim=2) # self.r[:,1:,0] = r_zero for t in range(start.min(), self.input_length): should_decode = decoded_len <= t self.r[:, t, 0] = torch.logaddexp(self.r[:, t - 1, 0], log_phi[:, t - 1]) + self.xs[:, t] self.r[:, t, 1] = ( torch.logaddexp(self.r[:, t - 1, 0], self.r[:, t - 1, 1]) + x_current[:, t, self.blank][:, None] ) if ~should_decode.any(): self.r[:, t] = torch.where(should_decode.unsqueeze(-1).unsqueeze(-1), self.r[:, t], self.logzero) return log_psi def _update_log_psi_with_eos(self, log_psi, cs, r_sum): # Update log_psi for eos positions eos_mask = (cs == self.eos) log_psi[eos_mask] = r_sum[:, -1].unsqueeze(1).expand_as(log_psi)[eos_mask] # Exclude blank probabilities if eos is not the blank if self.eos != self.blank: blank_mask = (cs == self.blank) log_psi[blank_mask] = self.logzero return log_psi def __call__(self, y, cs, decoded_len, samples_to_be_decoded, r_prev): """Compute CTC prefix scores for next labels :param y : prefix label sequence :param cs : array of next labels :param r_prev: previous CTC state :return ctc_scores, ctc_states """ # initialize CTC states # output_length = y.shape[1] - 1 # ignore sos # new CTC states are prepared as a frame x (n or b) x n_labels tensor # that corresponds to r_t^n(h) and r_t^b(h). # Dynamically resize r and xs to match num_labels if necessary num_labels = cs.shape[1] number_of_current_samples = cs.shape[0] self._resize_tensors(number_of_current_samples, num_labels) # Create a view of the current input frame x_current = self.x[samples_to_be_decoded] self.xs = torch.gather(x_current, 2, cs.unsqueeze(1).expand(-1, self.input_length, -1)) # Initialize r for the first frame self._initialize_r(decoded_len) # prepare forward probabilities for the last label r_sum = torch.logaddexp(r_prev[:, :, 0], r_prev[:, :, 1]) # log(r_t^n(g) + r_t^b(g)) last = y[:, -1] # precompute log_phi log_phi = self._compute_log_phi(r_sum, cs, last, decoded_len, r_prev) # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)), # and log prefix probabilities log(psi) log_psi = self._compute_log_psi(decoded_len, log_phi, x_current) # get P(...eos|X) that ends with the prefix itself log_psi = self._update_log_psi_with_eos(log_psi, cs, r_sum) # return the log prefix probability and CTC states, where the label axis # of the CTC states is moved to the first axis to slice it easily return log_psi, self.r class CTCRescorerLogitsProcessor(LogitsProcessor): def __init__( self, encoder_logits: torch.FloatTensor, encoder_output_lens: torch.Tensor, blank_token_id: int, pad_token_id: int, eos_token_id: int, bos_token_id: int, tokenizer: PreTrainedTokenizer, ctc_margin: int, ctc_weight: float, num_beams: int, debug: bool = False, ctc_tokens_to_score: int = 500 ): super().__init__() same_logits = torch.tensor(list((tokenizer.upper_cased_tokens.items()))) logits = torch.nn.functional.log_softmax(encoder_logits, dim=-1) logits[..., same_logits[:, 1]] = logits[..., same_logits[:, 0]] self.logits = logits self.ctc_prefix_scorer = CTCPrefixScore( self.logits, blank_token_id, eos_token_id, ) self.batch_size = logits.shape[0] self.input_length = logits.shape[1] self.num_tokens = logits.shape[2] self.device = logits.device self.ctc_weight = ctc_weight self.num_beams = num_beams self.ctc_state_prev, self.ctc_score_prev = self.ctc_prefix_scorer.initial_state() self.eos_token_id = eos_token_id self.bos_token_id = bos_token_id self.tokenizer = tokenizer self.pad_token_id = pad_token_id self.blank_token_id = blank_token_id self.debug = False self.first_timestamp_token_id = tokenizer.get_vocab()["<|0.00|>"] self.tmp_ctc_scores = torch.empty((self.batch_size, self.num_tokens - 1), device=self.device) self.tmp_ctc_states = torch.empty((self.batch_size, self.num_tokens - 1, self.input_length, 2), device=self.device) self.ctc_tokens_to_score = ctc_tokens_to_score def analyze_predictions(self, scores, ctc_scores, next_token_scores, input_ids, k=10): print("\n" + "#" * 100) batch_size = input_ids.shape[0] best_att_ids = scores.topk(k=k, dim=1) ctc_scores[:, self.first_timestamp_token_id:] = self.ctc_prefix_scorer.logzero best_ctc_ids = ctc_scores.topk(k=k, dim=1) best_ids = next_token_scores.topk(k=k, dim=1) decoded_prefixes = self.tokenizer.batch_decode( input_ids, decode_with_timestamps=True, skip_special_tokens=False ) def prepare_and_decode(best_ids_tensor): new_tensor = torch.zeros((batch_size, k * 2), dtype=torch.long) new_tensor[:, 0::2] = best_ids_tensor.indices new_tensor[:, 1::2] = self.tokenizer.vocab['#'] # Flatten to (batch_size * k, 2) flat_tensor = new_tensor.view(-1, 2) decoded = self.tokenizer.batch_decode( flat_tensor, decode_with_timestamps=True, skip_special_tokens=False ) # Reshape back to (batch_size, k) decoded = [(decoded[i * k:(i + 1) * k]) for i in range(batch_size)] return decoded decoded_att = prepare_and_decode(best_att_ids) decoded_ctc = prepare_and_decode(best_ctc_ids) decoded_next = prepare_and_decode(best_ids) for idx in range(batch_size): print("-" * 80) print(f"HYPOTHESIS {idx}") print("\nPREFIX:") print(decoded_prefixes[idx]) def print_with_pandas(tokens, scores, title): df = pd.DataFrame([tokens, [f"{s.item():.2f}" for s in scores]]) df.index = [f"{title}", "Score"] print(f"\n{title}:") print(df.to_string(index=True, header=False)) print_with_pandas(decoded_att[idx], best_att_ids.values[idx], "ATT_TOKENS") print_with_pandas(decoded_ctc[idx], best_ctc_ids.values[idx], "CTC_TOKENS") print_with_pandas(decoded_next[idx], best_ids.values[idx], "NEXT_TOKENS") print(f"\nCTC_EOS: {ctc_scores[idx, self.tokenizer.eos_token_id].item():.2f}") print() print("#" * 100) def update_state(self, best_ids, beam_idx): mask = best_ids < self.first_timestamp_token_id self.ctc_state_prev = torch.where(mask.unsqueeze(-1).unsqueeze(-1), self.tmp_ctc_states[beam_idx, best_ids], self.ctc_state_prev[beam_idx]) self.ctc_score_prev = torch.where(mask.unsqueeze(-1), self.tmp_ctc_scores[beam_idx, best_ids].unsqueeze(-1), self.ctc_score_prev[beam_idx]) def __call__(self, input_ids_orig: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: input_ids = input_ids_orig.clone() # Remove prefix from CTC scoring if (input_ids[:, 0] != self.bos_token_id).any(): input_ids = torch.stack( [row[(row == self.bos_token_id).nonzero(as_tuple=True)[0].item():] for row in input_ids]) # Remove task/lang/timestamp tokens from input_ids input_prefix_len = len(self.tokenizer.prefix_tokens) if input_prefix_len > 1: input_ids = input_ids[:, input_prefix_len - 1:] # Setup the first token to be the blank token(sos) input_ids[:, 0] = self.blank_token_id # If there is last token in input_ids timestamp replicate last non-timestamp token which could be potentially even the first token decoded_len = torch.logical_and(input_ids <= self.first_timestamp_token_id, input_ids != self.blank_token_id).sum(dim=1) mask = torch.logical_and(input_ids[:, -1] >= self.first_timestamp_token_id, input_ids[:, -1] != self.blank_token_id) last_non_timestamp_token = torch.gather(input_ids, 1, torch.logical_or(input_ids < self.first_timestamp_token_id, input_ids == self.blank_token_id).sum(dim=1, keepdim=True) - 1) input_ids[mask, -1] = last_non_timestamp_token[mask, 0] # If there is no eos token in the last position, we need to continue decoding to_be_decoded = input_ids[:, -1] != self.eos_token_id self.tmp_ctc_scores[:] = self.ctc_prefix_scorer.logzero input_ids_local = input_ids[to_be_decoded] ids_to_score = torch.topk(scores[:, :self.first_timestamp_token_id], k=self.ctc_tokens_to_score).indices # always score EOS token if not present put on position of last id is_eos_present = (ids_to_score == self.eos_token_id).any(dim=1) ids_to_score[~is_eos_present, self.ctc_tokens_to_score - 1] = self.eos_token_id decoded_len_local = decoded_len[to_be_decoded] ctc_scores_local, ctc_states_local = self.ctc_prefix_scorer(input_ids_local, ids_to_score[to_be_decoded], decoded_len_local, to_be_decoded, self.ctc_state_prev[to_be_decoded]) # As the CTC scorer might run on subset of samples, we need to scatter the results back to the original batch self.tmp_ctc_scores[to_be_decoded] = (self.tmp_ctc_scores[to_be_decoded] .scatter(1, ids_to_score[to_be_decoded], ctc_scores_local)) self.tmp_ctc_states[to_be_decoded] = (self.tmp_ctc_states[to_be_decoded].permute(0, 2, 3, 1) .scatter(3, ids_to_score[to_be_decoded].unsqueeze(1).unsqueeze(1) .repeat(1, *ctc_states_local.shape[1:3], 1), ctc_states_local) .permute(0, 3, 1, 2)) # Set the CTC score for the timestamp tokens to the maximum to prefer them over the rest self.tmp_ctc_scores[:, self.first_timestamp_token_id:] = self.tmp_ctc_scores.max(dim=1).values[:, None] ctc_scores = self.tmp_ctc_scores - self.ctc_score_prev next_token_scores = (1 - self.ctc_weight) * scores + self.ctc_weight * ctc_scores if self.debug: self.analyze_predictions(scores, ctc_scores, next_token_scores, input_ids_orig) return next_token_scores class LogSoftmaxProcessor(LogitsProcessor): def __init__( self, ): super().__init__() def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: scores = torch.nn.functional.log_softmax(scores, dim=-1) return scores class GreedyCTCDecoder(torch.nn.Module): def __init__(self, tokenizer, blank=0): super().__init__() self.blank = blank self.tokenizer = tokenizer def forward(self, emission: torch.Tensor) -> List[str]: """Given a sequence emission over labels, get the best path Args: emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`. Returns: List[str]: The resulting transcript """ indices = torch.argmax(emission, dim=-1) # [num_seq,] indices = [torch.unique_consecutive(index, dim=-1) for index in indices] indices = [index[index != self.blank] for index in indices] indices = torch.nn.utils.rnn.pad_sequence(indices, batch_first=True, padding_value=self.tokenizer.pad_token_id) indices[indices >= len(self.tokenizer)] = self.tokenizer.unk_token_id return indices def ctc_greedy_decode(logits: torch.Tensor, blank, pad_token_id) -> torch.Tensor: idxs = torch.argmax(logits, dim=-1) for i, prediction in enumerate(idxs): deduplicated = [k for k, g in it.groupby(prediction) if k != blank] idxs[i, : len(deduplicated)] = torch.tensor(deduplicated) idxs[i, len(deduplicated):] = pad_token_id return idxs