from typing import Optional

import torch
from transformers import WhisperTimeStampLogitsProcessor


def remove_fake_elements(inputs, per_group_sizes):
    """Drop the padded ("fake") slots from a flat batch that was padded to
    `max_spks` entries per group, keeping the first `per_group_sizes[i]`
    entries of each group."""
    max_spks = per_group_sizes.max()
    number_of_groups = per_group_sizes.shape[0]
    outputs = []
    # Reshape the flat batch into (groups, max_spks, ...) so that the real and
    # padded elements of each group can be separated.
    inputs = inputs.view(number_of_groups, max_spks, *inputs.shape[1:])
    for i, group_size in enumerate(per_group_sizes):
        outputs.append(inputs[i, :group_size])
    outputs = torch.cat(outputs, dim=0)
    return outputs
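
# Illustrative sketch (assumed toy shapes, not from the original source):
# three groups padded to max_spks == 2 with real sizes [2, 1, 2]; the padded
# slot of group 1 (row 3) is dropped.
#
#   x = torch.arange(6).unsqueeze(-1)        # shape (6, 1), rows 0..5
#   sizes = torch.tensor([2, 1, 2])
#   remove_fake_elements(x, sizes)           # keeps rows 0, 1, 2, 4, 5 -> shape (5, 1)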


class WhisperTimeStampLogitsProcessorCustom(WhisperTimeStampLogitsProcessor):
    """Variant of `WhisperTimeStampLogitsProcessor` that additionally supports an
    optional lower bound on the first generated timestamp
    (`min_initial_timestamp_index`)."""

    def __init__(
        self,
        generate_config,
        begin_index: Optional[int] = None,
        _detect_timestamp_from_logprob: Optional[bool] = None,
    ):
        self.no_timestamps_token_id = generate_config.no_timestamps_token_id
        # Token ids >= timestamp_begin are timestamp tokens; everything below is
        # text or special tokens.
        self.timestamp_begin = generate_config.no_timestamps_token_id + 1
        self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id

        # Force a timestamp when the total timestamp log-probability exceeds that
        # of any single text token.
        self._detect_timestamp_from_logprob = (
            _detect_timestamp_from_logprob
            if _detect_timestamp_from_logprob is not None
            else getattr(generate_config, "_detect_timestamp_from_logprob", True)
        )

        num_forced_ids = (
            len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
        )
        self.begin_index = begin_index or (num_forced_ids + 1)

        # Bounds (in timestamp-token units) on the first generated timestamp.
        self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
        self.min_initial_timestamp_index = getattr(generate_config, "min_initial_timestamp_index", None)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Never emit the <|notimestamps|> token: timestamps are always generated.
        scores_processed = scores.clone()
        scores_processed[:, self.no_timestamps_token_id] = -float("inf")

        for k in range(input_ids.shape[0]):
            sampled_tokens = input_ids[k, self.begin_index:]
            seq = list(sampled_tokens.tolist())

            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:
                    # A closing timestamp pair was just produced; the next token must be text.
                    scores_processed[k, self.timestamp_begin:] = -float("inf")
                else:
                    # An opening timestamp must be followed by another timestamp, not text.
                    scores_processed[k, : self.eos_token_id] = -float("inf")

            # Timestamps have to increase monotonically.
            timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
            if timestamps.numel() > 0:
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    # Avoid emitting the same timestamp twice.
                    timestamp_last = timestamps[-1] + 1
                scores_processed[k, self.timestamp_begin: timestamp_last] = -float("inf")

        # At the first generated position only timestamps (or EOS) are allowed,
        # optionally clamped to [min_initial_timestamp_index, max_initial_timestamp_index].
        if input_ids.shape[1] == self.begin_index:
            eos_scores = scores_processed[:, self.eos_token_id].clone()
            scores_processed[:, : self.timestamp_begin] = -float("inf")
            scores_processed[:, self.eos_token_id] = eos_scores

            if self.max_initial_timestamp_index is not None:
                last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
                scores_processed[:, last_allowed + 1:] = -float("inf")
            if self.min_initial_timestamp_index is not None:
                first_allowed = self.timestamp_begin + self.min_initial_timestamp_index
                scores_processed[:, self.timestamp_begin:first_allowed] = -float("inf")

        # If the probability mass on the timestamp tokens exceeds that of any single
        # text token, force a timestamp.
        logprobs = torch.nn.functional.log_softmax(scores_processed.float(), dim=-1)
        for k in range(input_ids.shape[0]):
            timestamp_logprob = logprobs[k, self.timestamp_begin:].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
                scores_processed[k, : self.timestamp_begin] = -float("inf")

        return scores_processed
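

# Usage sketch (assumptions: a Hugging Face WhisperForConditionalGeneration
# `model` and preprocessed `input_features` already exist; the attribute
# `min_initial_timestamp_index` set below is illustrative, not part of the
# stock generation config):
#
#   from transformers import LogitsProcessorList
#
#   model.generation_config.min_initial_timestamp_index = 1   # assumed setting
#   ts_processor = WhisperTimeStampLogitsProcessorCustom(model.generation_config)
#   predicted_ids = model.generate(
#       input_features,
#       logits_processor=LogitsProcessorList([ts_processor]),
#   )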