import torch
import torch.nn as nn
import transformers
import transformers.modeling_outputs
from transformers.models.whisper import modeling_whisper as whisper


class WhisperEncoder(whisper.WhisperEncoder):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers' Whisper encoder, with only a few fixes:
    1. `base_model_prefix` updated to allow calling `.from_pretrained` directly on the encoder
    2. allow less than 30 seconds of audio padding to be passed in:
        - relaxed the ValueError check on the `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal
        - `embed_pos` is now sliced to match the length of `inputs_embeds`

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    """

    base_model_prefix = "model.encoder"

    def forward(
        self,
        input_features,
        attention_mask=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
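        # Maximum mel-frame length the encoder can take: max_source_positions
        # output frames times the two conv strides. With the default Whisper
        # config this is 1500 * 1 * 2 = 3000 mel frames, i.e. 30 seconds of audio.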
        expected_seq_length = (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

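        # Fall back to the model config for any output flags the caller left unset.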
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
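        # Conv front end: two GELU-activated 1-D convolutions; the second has
        # stride 2, so the frame count is halved before the transformer layers.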
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
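        # Fix 2 from the class docstring: slice the positional embeddings to the
        # actual number of frames so inputs shorter than 30 seconds line up.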
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

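            # LayerDrop: during training, randomly skip this layer with
            # probability `self.layerdrop`.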
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
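                # Trade compute for memory: recompute this layer's activations
                # during the backward pass when gradient checkpointing is enabled.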
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        None,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        None,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

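        # Keep the upstream return convention: a plain tuple when return_dict is
        # False, otherwise a BaseModelOutput.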
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return transformers.modeling_outputs.BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
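

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of the two tweaks described in the class docstring:
# loading encoder weights straight from a full Whisper checkpoint via
# `.from_pretrained`, and feeding mel features shorter than the padded
# 30-second window. The checkpoint name "openai/whisper-small" and the
# 10-second input length are arbitrary choices for illustration.
if __name__ == "__main__":
    encoder = WhisperEncoder.from_pretrained("openai/whisper-small")
    encoder.eval()

    # 10 seconds of audio -> 1000 mel frames (Whisper computes 100 mel frames
    # per second), instead of the 3000 frames of a fully padded 30-second clip.
    mel_features = torch.randn(1, encoder.config.num_mel_bins, 1000)

    with torch.no_grad():
        outputs = encoder(mel_features)

    # conv2 has stride 2, so 1000 mel frames become 500 encoder states.
    print(outputs.last_hidden_state.shape)  # (1, 500, d_model)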