mrprimenotes commited on
Commit
5993784
·
verified ·
1 Parent(s): 9246022

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +61 -59
model.py CHANGED
@@ -1064,65 +1064,6 @@ class WhisperEncoder(WhisperPreTrainedModel):
1064
  return embed_pos[:max_pos_len]
1065
  else:
1066
  return embed_pos[-max_pos_len:]
1067
-
1068
- # CUSTOM (Monkeypatch the generation method)
1069
- def patch_generate():
1070
- """
1071
- Monkey patches the WhisperGenerationMixin to use dynamic stride calculation
1072
- """
1073
- original_generate = WhisperGenerationMixin.generate
1074
-
1075
- def get_conv_stride(self):
1076
- """Calculate total stride of all conv layers"""
1077
- total_stride = 1
1078
- for layer in self.model.encoder.conv_layers:
1079
- total_stride *= layer.stride[0]
1080
- return total_stride
1081
-
1082
- def generate_wrapper(self, *args, **kwargs):
1083
- # Store the original function logic
1084
- original_code = original_generate.__code__
1085
-
1086
- # Create a modified version of the function that uses our stride calculation
1087
- modified_code = types.CodeType(
1088
- original_code.co_argcount,
1089
- original_code.co_posonlyargcount,
1090
- original_code.co_kwonlyargcount,
1091
- original_code.co_nlocals,
1092
- original_code.co_stacksize,
1093
- original_code.co_flags,
1094
- original_code.co_code.replace(
1095
- # Replace the hardcoded stride calculation with our dynamic one
1096
- b"self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]",
1097
- b"self.get_conv_stride()",
1098
- ),
1099
- original_code.co_consts,
1100
- original_code.co_names,
1101
- original_code.co_varnames,
1102
- original_code.co_filename,
1103
- original_code.co_name,
1104
- original_code.co_firstlineno,
1105
- original_code.co_lnotab,
1106
- original_code.co_freevars,
1107
- original_code.co_cellvars,
1108
- )
1109
-
1110
- # Create a new function with the modified code
1111
- new_generate = types.FunctionType(
1112
- modified_code,
1113
- original_generate.__globals__,
1114
- original_generate.__name__,
1115
- original_generate.__defaults__,
1116
- original_generate.__closure__,
1117
- )
1118
-
1119
- # Bind the function to the instance and call it
1120
- return new_generate(self, *args, **kwargs)
1121
-
1122
- # Add the stride calculation method to the mixin
1123
- WhisperGenerationMixin.get_conv_stride = get_conv_stride
1124
- # Replace the original generate method
1125
- WhisperGenerationMixin.generate = generate_wrapper
1126
 
1127
 
1128
  def forward(
@@ -1862,9 +1803,70 @@ class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTr
1862
  self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
1863
  self.max_target_positions = config.max_target_positions
1864
 
 
 
1865
  # Initialize weights and apply final processing
1866
  self.post_init()
1867
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1868
  def get_encoder(self):
1869
  return self.model.get_encoder()
1870
 
 
1064
  return embed_pos[:max_pos_len]
1065
  else:
1066
  return embed_pos[-max_pos_len:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1067
 
1068
 
1069
  def forward(
 
1803
  self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
1804
  self.max_target_positions = config.max_target_positions
1805
 
1806
+ self.patch_generate()
1807
+
1808
  # Initialize weights and apply final processing
1809
  self.post_init()
1810
 
1811
+ # CUSTOM (Monkeypatch the generation method)
1812
+ def patch_generate(self):
1813
+ """
1814
+ Monkey patches the WhisperGenerationMixin to use dynamic stride calculation
1815
+ """
1816
+ original_generate = WhisperGenerationMixin.generate
1817
+
1818
+ def get_conv_stride(self):
1819
+ """Calculate total stride of all conv layers"""
1820
+ total_stride = 1
1821
+ for layer in self.model.encoder.conv_layers:
1822
+ total_stride *= layer.stride[0]
1823
+ return total_stride
1824
+
1825
+ def generate_wrapper(self, *args, **kwargs):
1826
+ # Store the original function logic
1827
+ original_code = original_generate.__code__
1828
+
1829
+ # Create a modified version of the function that uses our stride calculation
1830
+ modified_code = types.CodeType(
1831
+ original_code.co_argcount,
1832
+ original_code.co_posonlyargcount,
1833
+ original_code.co_kwonlyargcount,
1834
+ original_code.co_nlocals,
1835
+ original_code.co_stacksize,
1836
+ original_code.co_flags,
1837
+ original_code.co_code.replace(
1838
+ # Replace the hardcoded stride calculation with our dynamic one
1839
+ b"self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]",
1840
+ b"self.get_conv_stride()",
1841
+ ),
1842
+ original_code.co_consts,
1843
+ original_code.co_names,
1844
+ original_code.co_varnames,
1845
+ original_code.co_filename,
1846
+ original_code.co_name,
1847
+ original_code.co_firstlineno,
1848
+ original_code.co_lnotab,
1849
+ original_code.co_freevars,
1850
+ original_code.co_cellvars,
1851
+ )
1852
+
1853
+ # Create a new function with the modified code
1854
+ new_generate = types.FunctionType(
1855
+ modified_code,
1856
+ original_generate.__globals__,
1857
+ original_generate.__name__,
1858
+ original_generate.__defaults__,
1859
+ original_generate.__closure__,
1860
+ )
1861
+
1862
+ # Bind the function to the instance and call it
1863
+ return new_generate(self, *args, **kwargs)
1864
+
1865
+ # Add the stride calculation method to the mixin
1866
+ WhisperGenerationMixin.get_conv_stride = get_conv_stride
1867
+ # Replace the original generate method
1868
+ WhisperGenerationMixin.generate = generate_wrapper
1869
+
1870
  def get_encoder(self):
1871
  return self.model.get_encoder()
1872