add "_get_attr_from_logit_processors"
Browse files
model.py
CHANGED
@@ -1846,7 +1846,14 @@ def _pad_to_max_length(
|
|
1846 |
|
1847 |
sequences = torch.stack(sequences, dim=0)
|
1848 |
return sequences
|
1849 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1850 |
# CUSTOM (patch the generation method)
|
1851 |
class CustomWhisperGenerationMixin(WhisperGenerationMixin):
|
1852 |
def generate(
|
|
|
1846 |
|
1847 |
sequences = torch.stack(sequences, dim=0)
|
1848 |
return sequences
|
1849 |
+
|
1850 |
+
def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
|
1851 |
+
if logits_processor is not None:
|
1852 |
+
logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
|
1853 |
+
if logit_processor:
|
1854 |
+
return getattr(logit_processor, attribute_name, None)
|
1855 |
+
return None
|
1856 |
+
|
1857 |
# CUSTOM (patch the generation method)
|
1858 |
class CustomWhisperGenerationMixin(WhisperGenerationMixin):
|
1859 |
def generate(
|