mrprimenotes committed · Commit 46a0628 · verified · 1 Parent(s): f57afeb

Update model.py

Files changed (1): model.py (+8, -8)
model.py CHANGED
@@ -271,7 +271,7 @@ class WhisperAttention(nn.Module):
         bias: bool = True,
         is_causal: bool = False,
         layer_idx: Optional[int] = None,
-        config: Optional[WhisperConfig] = None,
+        config: Optional[CustomWhisperConfig] = None,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -616,7 +616,7 @@ WHISPER_ATTENTION_CLASSES = {
 
 # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
 class WhisperEncoderLayer(nn.Module):
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__()
         self.embed_dim = config.d_model
 
@@ -686,7 +686,7 @@ class WhisperEncoderLayer(nn.Module):
 
 
 class WhisperDecoderLayer(nn.Module):
-    def __init__(self, config: WhisperConfig, layer_idx: int = None):
+    def __init__(self, config: CustomWhisperConfig, layer_idx: int = None):
         super().__init__()
         self.embed_dim = config.d_model
 
@@ -803,7 +803,7 @@ class WhisperDecoderLayer(nn.Module):
 
 
 class WhisperPreTrainedModel(PreTrainedModel):
-    config_class = WhisperConfig
+    config_class = CustomWhisperConfig
     base_model_prefix = "model"
     main_input_name = "input_features"
     supports_gradient_checkpointing = True
@@ -982,7 +982,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
         config: WhisperConfig
     """
 
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.encoder_layerdrop
@@ -1271,7 +1271,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
 
     main_input_name = "input_ids"
 
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
        super().__init__(config)
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
@@ -1674,7 +1674,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
     WHISPER_START_DOCSTRING,
 )
 class WhisperModel(WhisperPreTrainedModel):
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__(config)
 
         self.encoder = WhisperEncoder(config)
@@ -1849,7 +1849,7 @@ class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
     base_model_prefix = "model"
     _tied_weights_keys = ["proj_out.weight"]
 
-    def __init__(self, config: WhisperConfig):
+    def __init__(self, config: CustomWhisperConfig):
         super().__init__(config)
         self.model = WhisperModel(config)
         self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
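
For context: every hunk in this commit swaps the type annotation (and the `config_class` attribute on `WhisperPreTrainedModel`) from `WhisperConfig` to `CustomWhisperConfig`, so the diff presupposes that a `CustomWhisperConfig` class is importable in model.py. Its actual definition is not part of this commit; the following is only a minimal sketch of what such a class could look like, assuming it is a plain subclass of the stock `WhisperConfig` from transformers.

from transformers import WhisperConfig


class CustomWhisperConfig(WhisperConfig):
    """Hypothetical sketch, not the repository's actual definition.

    A bare subclass is already sufficient for the type hints in this diff
    and for the `config_class = CustomWhisperConfig` override to work;
    any extra fields the repo defines would be added in __init__ here.
    """

    # Keeping model_type as "whisper" lets existing Whisper config.json
    # files deserialize into this class unchanged.
    model_type = "whisper"

Because `WhisperPreTrainedModel.config_class` now points at `CustomWhisperConfig`, calls such as `CustomWhisperForConditionalGeneration.from_pretrained(...)` will parse a checkpoint's config.json into the custom config class automatically, which is presumably the point of the change.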