Update model.py
model.py
CHANGED
@@ -271,7 +271,7 @@ class WhisperAttention(nn.Module):
         bias: bool = True,
         is_causal: bool = False,
         layer_idx: Optional[int] = None,
-        config: Optional[WhisperConfig] = None,
+        config: Optional[CustomWhisperConfig] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -616,7 +616,7 @@ WHISPER_ATTENTION_CLASSES = {
 
 # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
 class WhisperEncoderLayer(nn.Module):
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__()
         self.embed_dim = config.d_model
 
@@ -686,7 +686,7 @@ class WhisperEncoderLayer(nn.Module):
 
 
 class WhisperDecoderLayer(nn.Module):
-    def __init__(self, config: WhisperConfig, layer_idx: int = None):
+    def __init__(self, config: CustomWhisperConfig, layer_idx: int = None):
         super().__init__()
         self.embed_dim = config.d_model
 
@@ -803,7 +803,7 @@ class WhisperDecoderLayer(nn.Module):
 
 
 class WhisperPreTrainedModel(PreTrainedModel):
-    config_class = WhisperConfig
+    config_class = CustomWhisperConfig
     base_model_prefix = "model"
     main_input_name = "input_features"
     supports_gradient_checkpointing = True
@@ -982,7 +982,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
         config: WhisperConfig
     """
 
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.encoder_layerdrop
@@ -1271,7 +1271,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
 
     main_input_name = "input_ids"
 
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -1674,7 +1674,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
     WHISPER_START_DOCSTRING,
 )
 class WhisperModel(WhisperPreTrainedModel):
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__(config)
 
         self.encoder = WhisperEncoder(config)
@@ -1849,7 +1849,7 @@ class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
     base_model_prefix = "model"
     _tied_weights_keys = ["proj_out.weight"]
 
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__(config)
         self.model = WhisperModel(config)
         self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
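Taken together, these hunks replace every WhisperConfig type hint, plus the config_class attribute on WhisperPreTrainedModel, with CustomWhisperConfig. The class itself is not defined in this diff; below is a minimal sketch of what it presumably looks like, assuming it subclasses WhisperConfig from transformers (so the d_model, dropout, encoder_layerdrop, etc. attributes read by these modules keep working). The model_type value and the prompt_embedding_dim field are hypothetical placeholders, not taken from this repository.

    # Minimal sketch, not the repository's actual definition.
    from transformers import WhisperConfig

    class CustomWhisperConfig(WhisperConfig):
        model_type = "custom_whisper"  # hypothetical identifier

        def __init__(self, prompt_embedding_dim: int = 0, **kwargs):
            # Hypothetical extra hyperparameter; stock Whisper fields
            # are inherited via the parent constructor.
            self.prompt_embedding_dim = prompt_embedding_dim
            super().__init__(**kwargs)

Because PreTrainedModel.from_pretrained resolves the checkpoint's config through cls.config_class, setting config_class = CustomWhisperConfig means CustomWhisperForConditionalGeneration.from_pretrained parses configs with this class. Loading a stock checkpoint's config would then look like:

    # Logs a model_type mismatch warning ("whisper" vs "custom_whisper") but loads.
    config = CustomWhisperConfig.from_pretrained("openai/whisper-small")
    model = CustomWhisperForConditionalGeneration(config)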