# DiCoW_v3_MLC/utils.py

from typing import Optional
import torch
from transformers import WhisperTimeStampLogitsProcessor

def remove_fake_elements(inputs, per_group_sizes):
    """Drop the padded (fake) speaker slots from a flattened batch.

    `inputs` is expected to be flattened to (num_groups * max_spks, ...);
    only the first `per_group_sizes[i]` rows of each group are real.
    """
    max_spks = per_group_sizes.max()
    number_of_groups = per_group_sizes.shape[0]
    outputs = []
    inputs = inputs.view(number_of_groups, max_spks, *inputs.shape[1:])
    for i, group_size in enumerate(per_group_sizes):
        outputs.append(inputs[i, :group_size])
    outputs = torch.cat(outputs, dim=0)
    return outputs
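
# Illustrative sketch of `remove_fake_elements` (the shapes below are
# assumptions for the example, not taken from the model code):
#
#   per_group_sizes = torch.tensor([2, 3])   # real speakers in each group
#   hidden = torch.randn(2 * 3, 10)          # both groups padded to max_spks=3
#   remove_fake_elements(hidden, per_group_sizes).shape
#   # -> torch.Size([5, 10]); the padded third slot of group 0 is dropped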

class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
    """Timestamp logits processor that additionally supports an optional
    `min_initial_timestamp_index` lower bound on the first emitted timestamp."""

    def __init__(
        self,
        generate_config,
        begin_index: Optional[int] = None,
        _detect_timestamp_from_logprob: Optional[bool] = None,
    ):  # accepts the same kwargs as the parent class
        self.no_timestamps_token_id = generate_config.no_timestamps_token_id
        self.timestamp_begin = generate_config.no_timestamps_token_id + 1
        self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id

        # this variable is mostly just used for testing
        self._detect_timestamp_from_logprob = (
            _detect_timestamp_from_logprob
            if _detect_timestamp_from_logprob is not None
            else getattr(generate_config, "_detect_timestamp_from_logprob", True)
        )

        num_forced_ids = (
            len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
        )
        self.begin_index = begin_index or (num_forced_ids + 1)

        self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
        self.min_initial_timestamp_index = getattr(generate_config, "min_initial_timestamp_index", None)
        # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
        # self.max_initial_timestamp_index = 50
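
    # Construction sketch (illustrative; the checkpoint name and the attribute
    # value are assumptions, not part of this repo):
    #
    #   from transformers import GenerationConfig
    #   gen_cfg = GenerationConfig.from_pretrained("openai/whisper-small")
    #   gen_cfg.min_initial_timestamp_index = 1  # forbid <|0.00|> as the first timestamp
    #   processor = WhisperTimeStampLogitsProcessorCustom(gen_cfg)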

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # suppress <|notimestamps|>, which is handled by `without_timestamps`
        scores_processed = scores.clone()
        scores_processed[:, self.no_timestamps_token_id] = -float("inf")

        # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
        for k in range(input_ids.shape[0]):
            sampled_tokens = input_ids[k, self.begin_index:]
            seq = list(sampled_tokens.tolist())

            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    scores_processed[k, self.timestamp_begin:] = -float("inf")
                else:  # cannot be normal text tokens
                    scores_processed[k, : self.eos_token_id] = -float("inf")

            timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
            if timestamps.numel() > 0:
                # `timestamps` must not decrease; forbid timestamp tokens smaller than the last
                # The following lines are copied from https://github.com/openai/whisper/pull/914/files#r1137085090
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    # avoid emitting <|0.00|> again
                    timestamp_last = timestamps[-1] + 1

                scores_processed[k, self.timestamp_begin:timestamp_last] = -float("inf")

        # apply the `max_initial_timestamp` / `min_initial_timestamp` options
        if input_ids.shape[1] == self.begin_index:
            eos_scores = scores_processed[:, self.eos_token_id].clone()
            scores_processed[:, : self.timestamp_begin] = -float("inf")
            scores_processed[:, self.eos_token_id] = eos_scores

            if self.max_initial_timestamp_index is not None:
                last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
                scores_processed[:, last_allowed + 1:] = -float("inf")
            if self.min_initial_timestamp_index is not None:
                first_allowed = self.timestamp_begin + self.min_initial_timestamp_index
                scores_processed[:, self.timestamp_begin:first_allowed] = -float("inf")

        # if the total probability mass on timestamps exceeds any single text token, force a timestamp
        logprobs = torch.nn.functional.log_softmax(scores_processed.float(), dim=-1)
        for k in range(input_ids.shape[0]):
            timestamp_logprob = logprobs[k, self.timestamp_begin:].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
                scores_processed[k, : self.timestamp_begin] = -float("inf")

        return scores_processed
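

if __name__ == "__main__":
    # Minimal smoke test with a hand-built config. All token ids here are
    # illustrative assumptions, not taken from any real Whisper checkpoint:
    # a 12-token vocabulary where <|notimestamps|> is id 8 and timestamp
    # tokens occupy ids 9-11.
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        no_timestamps_token_id=8,
        eos_token_id=7,
        bos_token_id=6,
        forced_decoder_ids=None,
        max_initial_timestamp_index=2,
        min_initial_timestamp_index=None,
    )
    proc = WhisperTimeStampLogitsProcessorCustom(cfg)

    input_ids = torch.tensor([[6]])  # decoding has produced only the BOS token
    scores = torch.zeros(1, 12)      # uniform logits over the toy vocabulary
    out = proc(input_ids, scores)

    # <|notimestamps|> is always suppressed, and at the first decoding step
    # all plain text tokens are masked in favour of an initial timestamp.
    assert out[0, cfg.no_timestamps_token_id] == -float("inf")
    assert torch.isinf(out[0, : proc.timestamp_begin]).all()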