from dataclasses import dataclass
from typing import Optional

import torch
from transformers import WhisperConfig
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput, Seq2SeqModelOutput


@dataclass
class Seq2SeqLMOutputLosses(Seq2SeqLMOutput):
    """`Seq2SeqLMOutput` extended with separate encoder (CTC) and decoder losses."""
    enc_loss: Optional[torch.FloatTensor] = None
    dec_loss: Optional[torch.FloatTensor] = None
    encoder_logits: Optional[torch.FloatTensor] = None


@dataclass
class BaseModelOutputLogit(BaseModelOutput):
    """`BaseModelOutput` extended with the encoder's CTC logits."""
    logits: Optional[torch.FloatTensor] = None


@dataclass
class Seq2SeqModelOutputLogit(Seq2SeqModelOutput):
    """`Seq2SeqModelOutput` extended with the encoder's CTC logits."""
    encoder_logits: Optional[torch.FloatTensor] = None


class DiCoWConfig(WhisperConfig):
    """Configuration for DiCoW, a modified version of the `WhisperEncoder` model from
    the `transformers` library that supports CTC loss computation in the forward pass."""

    model_type = "DiCoW"

    def __init__(
        self,
        ctc_loss_reduction: str = "mean",
        final_dropout: float = 0.0,
        ctc_zero_infinity: bool = False,
        ctc_weight: float = 0.0,
        blank_token_id: Optional[int] = None,
        additional_layer: bool = False,
        additional_self_attention_layer: bool = False,
        sub_sample: bool = False,
        use_fddt: bool = True,
        fddt_is_diagonal: bool = True,
        fddt_bias_only: bool = False,
        fddt_use_silence: bool = True,
        fddt_use_target: bool = True,
        fddt_use_overlap: bool = True,
        fddt_use_non_target: bool = True,
        remove_timestamps_from_ctc: bool = False,
        apply_fddt_to_n_layers: int = -1,
        fddt_init: str = 'non-disturbing',  # 'random', 'non-disturbing', or 'disparagement'
        n_soft_prompts: int = 16,
        mt_num_speakers: int = 1,
        non_target_fddt_value: float = 0.0,
        use_initial_fddt: bool = False,
        scb_method: Optional[str] = None,
        scb_layers: int = -1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ctc_loss_reduction = ctc_loss_reduction
        self.final_dropout = final_dropout
        self.ctc_zero_infinity = ctc_zero_infinity
        self.ctc_weight = ctc_weight
        self.blank_token_id = blank_token_id
        self.additional_layer = additional_layer
        self.additional_self_attention_layer = additional_self_attention_layer
        self.sub_sample = sub_sample
        self.use_fddt = use_fddt
        self.fddt_is_diagonal = fddt_is_diagonal
        self.fddt_bias_only = fddt_bias_only
        self.fddt_use_silence = fddt_use_silence
        self.fddt_use_target = fddt_use_target
        self.fddt_use_overlap = fddt_use_overlap
        self.fddt_use_non_target = fddt_use_non_target
        self.remove_timestamps_from_ctc = remove_timestamps_from_ctc
        self.apply_fddt_to_n_layers = apply_fddt_to_n_layers
        self.fddt_init = fddt_init
        self.n_soft_prompts = n_soft_prompts
        self.mt_num_speakers = mt_num_speakers
        self.non_target_fddt_value = non_target_fddt_value
        self.use_initial_fddt = use_initial_fddt
        self.scb_method = scb_method
        self.scb_layers = scb_layers


# Index at which hidden states start in tuple-form model outputs.
_HIDDEN_STATES_START_POSITION = 2
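

# Minimal usage sketch. Since `DiCoWConfig` only adds plain attributes on top of
# `WhisperConfig`, it can be instantiated and serialized like any `transformers`
# config. The values below are illustrative, and `blank_token_id` is an assumed
# vocabulary id, not one taken from this module.
if __name__ == "__main__":
    config = DiCoWConfig(
        ctc_weight=0.3,        # blend the encoder CTC loss with the decoder loss
        blank_token_id=50256,  # assumption: token id reserved for the CTC blank
        mt_num_speakers=2,     # two-speaker multi-talker setup
    )
    print(config.model_type)                # -> "DiCoW"
    config.save_pretrained("dicow_config")  # serializes the extra fields to config.json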