mrprimenotes committed on
Commit
7518480
·
verified ·
1 Parent(s): bdd8185

update patch for generation method

Browse files
Files changed (1) hide show
  1. model.py +23 -18
model.py CHANGED
@@ -1801,25 +1801,13 @@ class WhisperModel(WhisperPreTrainedModel):
1801
  encoder_attentions=encoder_outputs.attentions,
1802
  )
1803
 
 
1804
 
1805
- @add_start_docstrings(
1806
- "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
1807
- WHISPER_START_DOCSTRING,
1808
- )
1809
- class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
1810
- base_model_prefix = "model"
1811
- _tied_weights_keys = ["proj_out.weight"]
1812
-
1813
- def __init__(self, config: CustomWhisperConfig):
1814
- super().__init__(config)
1815
- self.model = WhisperModel(config)
1816
- self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
1817
- self.max_target_positions = config.max_target_positions
1818
-
1819
- # Initialize weights and apply final processing
1820
- self.post_init()
1821
-
1822
- # CUSTOM (patch the generation method)
1823
  def generate(
1824
  self,
1825
  input_features: Optional[torch.Tensor] = None,
@@ -2142,6 +2130,23 @@ class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTr
2142
  for i in range(len(outputs.encoder_hidden_states))
2143
  )
2144
  return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2145
 
2146
  def _pad_to_max_length(
2147
  current_segments,
 
1801
  encoder_attentions=encoder_outputs.attentions,
1802
  )
1803
 
1804
+
1805
 
1806
+ # CUSTOM (patch the generation method)
1807
+ class CustomWhisperGenerationMixin(WhisperGenerationMixin):
1808
+ def __init__(self, *args, **kwargs):
1809
+ super().__init__(*args, **kwargs)
1810
+
 
 
 
 
 
 
 
 
 
 
 
 
 
1811
  def generate(
1812
  self,
1813
  input_features: Optional[torch.Tensor] = None,
 
2130
  for i in range(len(outputs.encoder_hidden_states))
2131
  )
2132
  return outputs
2133
+
2134
+ @add_start_docstrings(
2135
+ "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
2136
+ WHISPER_START_DOCSTRING,
2137
+ )
2138
+ class CustomWhisperForConditionalGeneration(CustomWhisperGenerationMixin, WhisperPreTrainedModel):
2139
+ base_model_prefix = "model"
2140
+ _tied_weights_keys = ["proj_out.weight"]
2141
+
2142
+ def __init__(self, config: CustomWhisperConfig):
2143
+ super().__init__(config)
2144
+ self.model = WhisperModel(config)
2145
+ self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
2146
+ self.max_target_positions = config.max_target_positions
2147
+
2148
+ # Initialize weights and apply final processing
2149
+ self.post_init()
2150
 
2151
  def _pad_to_max_length(
2152
  current_segments,