mrprimenotes commited on
Commit
4036053
·
verified ·
1 Parent(s): fa1ca6d

add "_get_attr_from_logit_processors"

Browse files
Files changed (1) hide show
  1. model.py +8 -1
model.py CHANGED
@@ -1846,7 +1846,14 @@ def _pad_to_max_length(
1846
 
1847
  sequences = torch.stack(sequences, dim=0)
1848
  return sequences
1849
-
 
 
 
 
 
 
 
1850
  # CUSTOM (patch the generation method)
1851
  class CustomWhisperGenerationMixin(WhisperGenerationMixin):
1852
  def generate(
 
1846
 
1847
  sequences = torch.stack(sequences, dim=0)
1848
  return sequences
1849
+
1850
+ def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name):
1851
+ if logits_processor is not None:
1852
+ logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None)
1853
+ if logit_processor:
1854
+ return getattr(logit_processor, attribute_name, None)
1855
+ return None
1856
+
1857
  # CUSTOM (patch the generation method)
1858
  class CustomWhisperGenerationMixin(WhisperGenerationMixin):
1859
  def generate(