from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss

from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperForConditionalGeneration,
    WhisperModel,
    shift_tokens_right,
    sinusoids,
)
from transformers.utils import logging

from .config import Seq2SeqLMOutputLosses, Seq2SeqModelOutputLogit, DiCoWConfig
from .encoder import CustomLinear, CustomDiagonalLinear, FDDT, DiCoWEncoder
from .generation import DiCoWGenerationMixin

logging.set_verbosity_debug()
logger = logging.get_logger("transformers")


class DiCoW(WhisperModel):
    """Whisper model whose encoder is replaced by DiCoWEncoder, which additionally
    consumes per-frame STNO (silence/target/non-target/overlap) masks."""

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.encoder = DiCoWEncoder(config)

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        per_group_sizes: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutputLosses]:
        r"""
        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, WhisperModel
        >>> from datasets import load_dataset

        >>> model = WhisperModel.from_pretrained("openai/whisper-base")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features
        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 2, 512]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            input_features = self._mask_input_features(input_features, attention_mask=attention_mask)

            encoder_outputs = self.encoder(
                input_features,
                output_attentions=output_attentions,
                output_hidden_states=True,
                head_mask=head_mask,
                return_dict=return_dict,
                stno_mask=stno_mask,
                per_group_sizes=per_group_sizes,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
        # elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
        #     raise ValueError("encoder_outputs should be of type BaseModelOutput when return_dict=True.")

        # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs.hidden_states[-1],
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutputLogit(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.hidden_states[-1],
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            encoder_logits=encoder_outputs.logits,
        )


class DiCoWForConditionalGeneration(DiCoWGenerationMixin, WhisperForConditionalGeneration):
    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.model = DiCoW(config)
        self.encoder_logits = None
        self.tokenizer = None
        self.vad_seek_callback = None
        self.stno_mask = None
        self.stno_mask_seek = None

    # We need this setter as we can't pass a function/method as a config argument.
    # JSON serialization fails at that point.
    def set_vad_seek_callback(self, vad_seek_callback):
        self.vad_seek_callback = vad_seek_callback

    def set_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer

    def _init_weights(self, module):
        std = self.config.init_std
        fddt_init = self.config.fddt_init
        # FDDT-related modules support three init schemes:
        #   'random'         -- normal init,
        #   'non-disturbing' -- identity transform,
        #   'disparagement'  -- identity scaled by init_eye_val.
        if isinstance(module, CustomLinear):
            with torch.no_grad():
                if fddt_init == 'random':
                    module.weight.data.normal_(mean=0.0, std=std)
                    if module.bias is not None:
                        module.bias.data.normal_(mean=0.0, std=std)
                elif fddt_init == 'non-disturbing':
                    module.weight.data = torch.eye(*module.weight.shape).data
                    if module.bias is not None:
                        module.bias.data.zero_()
                elif fddt_init == 'disparagement':
                    eye = torch.eye(*module.weight.shape)
                    eye *= module.init_eye_val
                    module.weight.data = eye.data
                    if module.bias is not None:
                        module.bias.data.zero_()
        elif isinstance(module, CustomDiagonalLinear):
            with torch.no_grad():
                if fddt_init == 'random':
                    module.weight.data.normal_(mean=0.0, std=std)
                    if module.bias is not None:
                        module.bias.data.normal_(mean=0.0, std=std)
                elif fddt_init == 'non-disturbing':
                    module.weight.data = torch.ones_like(module.weight.data).data
                    if module.bias is not None:
                        module.bias.data.zero_()
                elif fddt_init == 'disparagement':
                    module.weight.data = module.init_eye_val * torch.ones_like(module.weight.data).data
                    if module.bias is not None:
                        module.bias.data.zero_()
        elif isinstance(module, FDDT):
            if module.bias_only:
                if fddt_init == 'random':
                    module.target_linear.data.normal_(mean=0.0, std=std)
                    module.non_target_linear.data.normal_(mean=0.0, std=std)
                    module.overlap_linear.data.normal_(mean=0.0, std=std)
                    module.silence_linear.data.normal_(mean=0.0, std=std)
                else:
                    module.target_linear.data.zero_()
                    module.non_target_linear.data.zero_()
                    module.overlap_linear.data.zero_()
                    module.silence_linear.data.zero_()
        elif isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, WhisperEncoder):
            with torch.no_grad():
                embed_positions = module.embed_positions.weight
                embed_positions.copy_(sinusoids(*embed_positions.shape))
        elif isinstance(module, nn.LayerNorm):
            module.reset_parameters()
        elif isinstance(module, nn.MultiheadAttention):
            module._reset_parameters()
        elif isinstance(module, nn.ConvTranspose1d):
            module.reset_parameters()

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        stno_mask: Optional[torch.FloatTensor] = None,
        per_group_sizes: Optional[torch.LongTensor] = None,
        attention_mask_enc: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        upp_labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        is_valid: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
            or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
            only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
        >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_features = inputs.input_features

        >>> generated_ids = model.generate(inputs=input_features)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            stno_mask=stno_mask,
            per_group_sizes=per_group_sizes,
        )
        dec_lm_logits = self.proj_out(outputs.last_hidden_state)
        enc_lm_logits = outputs.encoder_logits

        loss = None
        ctc_loss = 0

        # remove fake inputs from labels and logits given per group sizes
        if is_valid is not None:
            if self.config.ctc_weight > 0.0:
                enc_lm_logits = enc_lm_logits[is_valid]
            dec_lm_logits = dec_lm_logits[is_valid]
            labels = labels[is_valid]
            upp_labels = upp_labels[is_valid]

        if labels is not None and self.config.ctc_weight > 0.0:
            # strip the Whisper prompt prefix and EOS tokens before computing the encoder CTC loss
            enc_labels = labels.clone()
            for token in self.tokenizer.prefix_tokens:
                if (enc_labels[:, 0] == token).all():
                    enc_labels = enc_labels[:, 1:]
            enc_labels[enc_labels == self.config.eos_token_id] = -100

            ctc_loss = self.get_encoder().get_loss(enc_lm_logits, enc_labels)

        if labels is not None:
            loss_fct = CrossEntropyLoss(reduction='none')
            # move labels to correct device to enable PP
            labels = labels.to(dec_lm_logits.device)
            # per-token minimum of the CE loss w.r.t. `labels` and `upp_labels`, averaged over tokens
            dec_loss1 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
            dec_loss2 = loss_fct(dec_lm_logits.view(-1, self.config.vocab_size), upp_labels.reshape(-1))
            dec_loss = torch.hstack((dec_loss1[..., None], dec_loss2[..., None])).min(dim=-1).values.mean()
            # interpolate between the decoder CE loss and the encoder CTC loss
            loss = (1 - self.config.ctc_weight) * dec_loss + self.config.ctc_weight * ctc_loss

        if not return_dict:
            output = (dec_lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutputLosses(
            loss=loss,
            logits=dec_lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
            encoder_logits=enc_lm_logits,
        )

    def _get_feat_extract_output_lengths(self, attention_mask: torch.Tensor) -> torch.Tensor:
        return (self.model.encoder._get_feat_extract_output_lengths(attention_mask) / 4).ceil()

    def freeze_except(self, prefixes_to_preheat):
        for name, param in self.named_parameters():
            param.requires_grad = False
            for prefix in prefixes_to_preheat:
                if name.startswith(prefix):
                    param.requires_grad = True

    def suppress_interactions(self):
        """Suppress the final projection in CoAttention blocks to let the original information flow through."""
        for name, param in self.named_parameters():
            if "interaction" in name and "cat_proj" in name:
                with torch.no_grad():
                    if "bias" in name:
                        param[:] = 0.
                    else:
                        param[:] *= 0.001
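

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API).
# Assumptions not taken from this file: that DiCoWConfig() is constructible
# with defaults, that the STNO mask is laid out as (batch, 4, encoder_frames),
# and that per_group_sizes holds the number of utterances per recording group.
# Verify these against DiCoWConfig and DiCoWEncoder before relying on it.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = DiCoWConfig()  # hypothetical default config
    model = DiCoWForConditionalGeneration(config).eval()

    batch, frames = 1, 3000  # 30 s of log-mel frames at Whisper's 10 ms hop
    input_features = torch.randn(batch, config.num_mel_bins, frames)
    # assumed STNO layout: one channel each for silence/target/non-target/overlap,
    # at the encoder's frame rate (the conv front-end halves the input length)
    stno_mask = torch.zeros(batch, 4, frames // 2)
    stno_mask[:, 1] = 1.0  # treat every frame as target-speaker speech
    decoder_input_ids = torch.tensor([[config.decoder_start_token_id]])

    with torch.no_grad():
        out = model(
            input_features=input_features,
            stno_mask=stno_mask,
            per_group_sizes=torch.tensor([batch]),
            decoder_input_ids=decoder_input_ids,
        )
    print(out.logits.shape)  # (batch, decoder_len, vocab_size)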