from typing import Optional

import torch
from torch import nn
from transformers.modeling_outputs import CausalLMOutput, BaseModelOutput
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperEncoderLayer,
    WHISPER_ATTENTION_CLASSES,
)

from .config import DiCoWConfig


class CustomLinear(nn.Linear):
    def __init__(self, *args, init_eye_val=0.0, is_diagonal=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_eye_val = init_eye_val


class CustomDiagonalLinear(nn.Module):
    def __init__(self, d_model, bias=True, init_eye_val=0.0):
        super().__init__()
        self.init_eye_val = init_eye_val
        self.weight = nn.Parameter(torch.full((d_model,), init_eye_val))
        self.bias = nn.Parameter(torch.zeros(d_model)) if bias else None

    def forward(self, input):
        out = input * self.weight
        if self.bias is not None:
            out += self.bias
        return out


class FDDT(nn.Module):
    """Frame-level Diarization Dependent Transformation.

    Each frame is transformed by a bias-only, diagonal, or full affine layer selected by
    its STNO class (silence, target, non-target, overlap), plus an optional interaction
    branch (``scb``); the per-class outputs are blended according to the soft STNO mask.
    """

    def __init__(self, d_model, non_target_rate=0.01, is_diagonal=False, bias_only=False,
                 use_silence=True, use_target=True, use_overlap=True, use_non_target=True,
                 use_interaction=False, scb_module: Optional[nn.Module] = None):
        super().__init__()
        if use_target:
            self.target_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
                CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=1.0))
        if use_non_target:
            self.non_target_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
                CustomDiagonalLinear(d_model, bias=True, init_eye_val=non_target_rate) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=non_target_rate))
        if use_overlap:
            self.overlap_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
                CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=1.0))
        if use_silence:
            self.silence_linear = nn.Parameter(torch.zeros(d_model)) if bias_only else (
                CustomDiagonalLinear(d_model, bias=True, init_eye_val=non_target_rate) if is_diagonal
                else CustomLinear(d_model, d_model, bias=True, init_eye_val=non_target_rate))
        if use_interaction:
            self.scb = scb_module if scb_module is not None else (
                nn.Parameter(torch.zeros(d_model)) if bias_only else (
                    CustomDiagonalLinear(d_model, bias=True, init_eye_val=1.0) if is_diagonal
                    else CustomLinear(d_model, d_model, bias=True, init_eye_val=1.0)))

        self.use_silence = use_silence
        self.use_target = use_target
        self.use_overlap = use_overlap
        self.use_non_target = use_non_target
        self.use_interaction = use_interaction
        self.bias_only = bias_only

    @staticmethod
    def mask_out_non_interaction_signal(hidden_states, mask):
        mask = torch.round(mask).bool()
        masked_hidden_states = hidden_states * mask
        return masked_hidden_states

    def forward(self, hidden_states, stno_mask):
        # stno_mask: (batch, 4 or 5, frames) soft class weights; add a trailing singleton
        # dim so it broadcasts against hidden_states of shape (batch, frames, d_model)
        stno_mask = stno_mask.to(hidden_states.device)[..., None]
        if self.bias_only:
            if self.use_silence:
                hidden_states += stno_mask[:, 0, ...] * self.silence_linear
            if self.use_target:
                hidden_states += stno_mask[:, 1, ...] * self.target_linear
            if self.use_non_target:
                hidden_states += stno_mask[:, 2, ...] * self.non_target_linear
            if self.use_overlap:
                hidden_states += stno_mask[:, 3, ...] * self.overlap_linear
            if self.use_interaction:
                hidden_states += stno_mask[:, 4, ...] * self.scb
        else:
            orig_hidden_states = hidden_states
            hidden_states = (
                (self.silence_linear(orig_hidden_states) if self.use_silence else orig_hidden_states) * stno_mask[:, 0, :]
                + (self.target_linear(orig_hidden_states) if self.use_target else orig_hidden_states) * stno_mask[:, 1, :]
                + (self.non_target_linear(orig_hidden_states) if self.use_non_target else orig_hidden_states) * stno_mask[:, 2, :]
                + (self.overlap_linear(orig_hidden_states) if self.use_overlap else orig_hidden_states) * stno_mask[:, 3, :]
                + (self.scb(self.mask_out_non_interaction_signal(orig_hidden_states, stno_mask[:, 4, :])) * stno_mask[:, 4, :]
                   if self.use_interaction
                   else (0 if stno_mask.size(1) == 4 else orig_hidden_states * stno_mask[:, 4, :]))
            )
        return hidden_states


class DiCoWEncoder(WhisperEncoder):
    config_class = DiCoWConfig

    def __init__(self, config: DiCoWConfig):
        super().__init__(config)
        self.ctc_weight = config.ctc_weight
        if config.additional_layer and self.ctc_weight > 0.0:
            self.additional_layer = WhisperEncoderLayer(config)
        if config.additional_self_attention_layer and self.ctc_weight > 0.0:
            self.additional_self_attention_layer = WHISPER_ATTENTION_CLASSES[config._attn_implementation](
                embed_dim=config.d_model,
                num_heads=config.encoder_attention_heads,
                dropout=config.attention_dropout,
                config=config,
            )
        if config.sub_sample and self.ctc_weight > 0.0:
            self.subsample_conv1 = nn.Conv1d(
                in_channels=config.d_model,
                out_channels=config.d_model,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            )
            self.subsample_conv2 = nn.Conv1d(
                in_channels=config.d_model,
                out_channels=config.d_model,
                kernel_size=3,
                stride=2,
                padding=1,
                bias=False,
            )
        if self.ctc_weight > 0.0:
            self.lm_head = nn.Linear(config.d_model, config.vocab_size + 1, bias=False)
        self.final_dropout = nn.Dropout(config.final_dropout)

        if config.use_fddt:
            num_fddts = (self.config.apply_fddt_to_n_layers
                         if self.config.apply_fddt_to_n_layers != -1 else len(self.layers))
            self.initial_fddt = FDDT(config.d_model,
                                     non_target_rate=config.non_target_fddt_value,
                                     is_diagonal=config.fddt_is_diagonal,
                                     bias_only=config.fddt_bias_only,
                                     use_silence=config.fddt_use_silence,
                                     use_target=config.fddt_use_target,
                                     use_overlap=config.fddt_use_overlap,
                                     use_non_target=config.fddt_use_non_target)
            is_mt = config.mt_num_speakers > 1
            num_scbs = ((self.config.scb_layers if self.config.scb_layers != -1 else len(self.layers))
                        if is_mt else 0)
            self.scbs_identity_layers = config.encoder_layers - num_scbs
            self.fddts = nn.ModuleList([
                FDDT(config.d_model,
                     non_target_rate=1.0,
                     is_diagonal=config.fddt_is_diagonal,
                     bias_only=config.fddt_bias_only,
                     use_silence=config.fddt_use_silence,
                     use_target=config.fddt_use_target,
                     use_overlap=config.fddt_use_overlap,
                     use_non_target=config.fddt_use_non_target,
                     use_interaction=is_mt,
                     )
                for i in range(num_fddts)
            ])

        # 30 seconds of 50 Hz timestamps, -1 to get to 0.0, and -6 for the number of tasks
        self.first_task_token = self.config.vocab_size - 30 * 50 - 1 - 6
        self.post_init()

    @classmethod
    def _load_pretrained_model(
        cls,
        model,
        state_dict,
        loaded_keys,
        resolved_archive_file,
        pretrained_model_name_or_path,
        **kwargs
    ):
        for key in list(state_dict.keys()):
            if key.startswith("encoder."):
                state_dict[key[8:]] = state_dict.pop(key)
                loaded_keys.remove(key)
                loaded_keys.append(key[8:])
        output = super()._load_pretrained_model(
            model,
            state_dict,
            loaded_keys,
            resolved_archive_file,
            pretrained_model_name_or_path,
            **kwargs
        )
        return output

    def get_loss(self, logits, labels):
        if labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be smaller than the vocabulary size ({self.config.vocab_size})")
{self.config.vocab_size}") if self.config.remove_timestamps_from_ctc: labels = torch.nn.utils.rnn.pad_sequence([label[label < self.first_task_token] for label in labels], padding_value=-100).T input_lengths = torch.full((logits.shape[0],), fill_value=logits.shape[1], device=logits.device) # assuming that padded tokens are filled with -100 # when not being attended to labels_mask = labels >= 0 target_lengths = labels_mask.sum(-1) # flattened_targets = labels_enc.masked_select(labels_mask) # ctc_loss doesn't support fp16 log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) with torch.backends.cudnn.flags(enabled=True): ctc_loss = nn.functional.ctc_loss( log_probs, labels, input_lengths, target_lengths, blank=logits.shape[-1] - 1, reduction=self.config.ctc_loss_reduction, zero_infinity=True, ) return ctc_loss def forward( self, input_features, attention_mask=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, stno_mask=None, per_group_sizes=None ): # For MT-ASR the input has shape (B X S) x F x T # we can use torch.view(B, S, F, -1) to obtain # new tensor with speaker dim expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0] if input_features.shape[-1] != expected_seq_length: if input_features.shape[-1] > expected_seq_length: return CausalLMOutput( logits=None, hidden_states=None, attentions=None, ) else: raise ValueError( f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}." ) output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) embed_pos = self.embed_positions.weight if hasattr(self, "shift_embeds") and self.shift_embeds: embed_pos = embed_pos[ torch.clamp(((stno_mask[:, 1, :] + stno_mask[:, 3, :]).cumsum(dim=-1) - 1), min=0).to(torch.long)] if self.config.use_fddt: inputs_embeds = self.initial_fddt(inputs_embeds, stno_mask) hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( len(self.layers) ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." 
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if self.config.use_fddt and idx < len(self.fddts):
                hidden_states = self.fddts[idx](hidden_states, stno_mask)

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            outputs = tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        else:
            outputs = BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
            )

        if hasattr(self, "additional_layer"):
            inter_output, = self.additional_layer(
                outputs.last_hidden_state,
                attention_mask=None,
                output_attentions=output_attentions,
                layer_head_mask=None,
            )
        elif hasattr(self, "additional_self_attention_layer"):
            inter_output, _, __ = self.additional_self_attention_layer(
                outputs.last_hidden_state,
                attention_mask=None,
                output_attentions=output_attentions,
                layer_head_mask=None,
            )
        else:
            inter_output = outputs.last_hidden_state
        inter_output = self.final_dropout(inter_output)

        if hasattr(self, "subsample_conv2"):
            inter_output = self.subsample_conv2(self.subsample_conv1(inter_output.transpose(1, 2))).transpose(1, 2)

        if self.ctc_weight > 0.0:
            logits = self.lm_head(inter_output)
        else:
            logits = None

        return CausalLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
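
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API). It assumes
# DiCoWConfig behaves like a WhisperConfig subclass (providing num_mel_bins,
# max_source_positions, d_model, ...) and accepts the DiCoW-specific fields read above
# as keyword arguments; the values below are arbitrary examples, not recommended
# settings.
if __name__ == "__main__":
    config = DiCoWConfig(
        ctc_weight=0.3,                 # > 0 enables the CTC head (lm_head)
        additional_layer=False,
        additional_self_attention_layer=False,
        sub_sample=False,
        final_dropout=0.0,
        use_fddt=True,
        apply_fddt_to_n_layers=-1,      # -1 = apply an FDDT before every encoder layer
        non_target_fddt_value=0.01,
        fddt_is_diagonal=True,
        fddt_bias_only=False,
        fddt_use_silence=True,
        fddt_use_target=True,
        fddt_use_overlap=True,
        fddt_use_non_target=True,
        mt_num_speakers=1,              # > 1 would enable the interaction (scb) branch
        scb_layers=-1,
    )
    encoder = DiCoWEncoder(config)

    # Whisper expects mel features of length max_source_positions * conv strides
    # (3000 frames for the default 30 s window); the STNO mask is aligned with the
    # downsampled frame rate, i.e. max_source_positions frames.
    mel_len = config.max_source_positions * encoder.conv1.stride[0] * encoder.conv2.stride[0]
    mels = torch.randn(1, config.num_mel_bins, mel_len)
    stno_mask = torch.zeros(1, 5, config.max_source_positions)
    stno_mask[:, 1, :] = 1.0            # pretend every frame is target speech

    out = encoder(mels, stno_mask=stno_mask)
    print(None if out.logits is None else out.logits.shape)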