Update the generation-method patch: move the custom `generate` override into `CustomWhisperGenerationMixin`
Browse files
model.py
CHANGED
@@ -1801,25 +1801,13 @@ class WhisperModel(WhisperPreTrainedModel):
|
|
1801 |
encoder_attentions=encoder_outputs.attentions,
|
1802 |
)
|
1803 |
|
|
|
1804 |
|
1805 |
-
|
1806 |
-
|
1807 |
-
|
1808 |
-
)
|
1809 |
-
|
1810 |
-
base_model_prefix = "model"
|
1811 |
-
_tied_weights_keys = ["proj_out.weight"]
|
1812 |
-
|
1813 |
-
def __init__(self, config: CustomWhisperConfig):
|
1814 |
-
super().__init__(config)
|
1815 |
-
self.model = WhisperModel(config)
|
1816 |
-
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
1817 |
-
self.max_target_positions = config.max_target_positions
|
1818 |
-
|
1819 |
-
# Initialize weights and apply final processing
|
1820 |
-
self.post_init()
|
1821 |
-
|
1822 |
-
# CUSTOM (patch the generation method)
|
1823 |
def generate(
|
1824 |
self,
|
1825 |
input_features: Optional[torch.Tensor] = None,
|
@@ -2142,6 +2130,23 @@ class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTr
|
|
2142 |
for i in range(len(outputs.encoder_hidden_states))
|
2143 |
)
|
2144 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2145 |
|
2146 |
def _pad_to_max_length(
|
2147 |
current_segments,
|
|
|
1801 |
encoder_attentions=encoder_outputs.attentions,
|
1802 |
)
|
1803 |
|
1804 |
+
|
1805 |
|
1806 |
+
# CUSTOM (patch the generation method)
|
1807 |
+
class CustomWhisperGenerationMixin(WhisperGenerationMixin):
|
1808 |
+
def __init__(self, *args, **kwargs):
|
1809 |
+
super().__init__(*args, **kwargs)
|
1810 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1811 |
def generate(
|
1812 |
self,
|
1813 |
input_features: Optional[torch.Tensor] = None,
|
|
|
2130 |
for i in range(len(outputs.encoder_hidden_states))
|
2131 |
)
|
2132 |
return outputs
|
2133 |
+
|
2134 |
+
@add_start_docstrings(
|
2135 |
+
"The Whisper Model with a language modeling head. Can be used for automatic speech recognition.",
|
2136 |
+
WHISPER_START_DOCSTRING,
|
2137 |
+
)
|
2138 |
+
class CustomWhisperForConditionalGeneration(CustomWhisperGenerationMixin, WhisperPreTrainedModel):
|
2139 |
+
base_model_prefix = "model"
|
2140 |
+
_tied_weights_keys = ["proj_out.weight"]
|
2141 |
+
|
2142 |
+
def __init__(self, config: CustomWhisperConfig):
|
2143 |
+
super().__init__(config)
|
2144 |
+
self.model = WhisperModel(config)
|
2145 |
+
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
2146 |
+
self.max_target_positions = config.max_target_positions
|
2147 |
+
|
2148 |
+
# Initialize weights and apply final processing
|
2149 |
+
self.post_init()
|
2150 |
|
2151 |
def _pad_to_max_length(
|
2152 |
current_segments,
|