Update model.py
Browse files
model.py
CHANGED
@@ -1064,65 +1064,6 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|
1064 |
return embed_pos[:max_pos_len]
|
1065 |
else:
|
1066 |
return embed_pos[-max_pos_len:]
|
1067 |
-
|
1068 |
-
# CUSTOM (Monkeypatch the generation method)
|
1069 |
-
def patch_generate():
    """
    Monkey patches the WhisperGenerationMixin to use dynamic stride calculation

    The stock `generate` reads the encoder stride as
    `self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]`.
    For an encoder that exposes `conv_layers` instead of `conv1`/`conv2`, the
    installed wrapper temporarily provides `conv1`/`conv2` stand-ins whose
    strides multiply to the product of all `conv_layers` strides, runs the
    unmodified `generate`, and then removes the stand-ins again.

    Encoders that already define `conv1` or `conv2`, or that have no
    `conv_layers`, are passed straight through to the original `generate`.

    The patch is global (it modifies the mixin class) and idempotent: calling
    this function again after the patch is installed does nothing.

    NOTE(review): the temporary attributes are not guarded by a lock, so
    concurrent `generate` calls on the same model from several threads are
    not protected.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    # Already patched: wrapping again would only stack wrappers.
    if getattr(WhisperGenerationMixin.generate, "_conv_stride_patched", False):
        return

    original_generate = WhisperGenerationMixin.generate

    class _StrideShim:
        """Minimal stand-in exposing only the `stride` tuple that `generate` reads."""

        __slots__ = ("stride",)

        def __init__(self, stride):
            self.stride = (stride,)

    def get_conv_stride(self):
        """Calculate total stride of all conv layers"""
        total_stride = 1
        for layer in self.model.encoder.conv_layers:
            total_stride *= layer.stride[0]
        return total_stride

    @functools.wraps(original_generate)
    def generate_wrapper(self, *args, **kwargs):
        encoder = self.model.encoder

        # Nothing to adapt: the hardcoded lookup already resolves (this also
        # covers nested calls made while the shims are installed), or there
        # are no conv_layers to derive a stride from.
        if (
            hasattr(encoder, "conv1")
            or hasattr(encoder, "conv2")
            or not hasattr(encoder, "conv_layers")
        ):
            return original_generate(self, *args, **kwargs)

        # conv1 carries the full stride and conv2 a neutral 1, so the
        # hardcoded product conv1.stride[0] * conv2.stride[0] is the total.
        encoder.conv1 = _StrideShim(self.get_conv_stride())
        encoder.conv2 = _StrideShim(1)
        try:
            return original_generate(self, *args, **kwargs)
        finally:
            # Always restore the encoder, even when generate raises.
            del encoder.conv1
            del encoder.conv2

    generate_wrapper._conv_stride_patched = True

    # Add the stride calculation method to the mixin
    WhisperGenerationMixin.get_conv_stride = get_conv_stride
    # Replace the original generate method
    WhisperGenerationMixin.generate = generate_wrapper
|
1126 |
|
1127 |
|
1128 |
def forward(
|
@@ -1862,9 +1803,70 @@ class CustomWhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTr
|
|
1862 |
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
1863 |
self.max_target_positions = config.max_target_positions
|
1864 |
|
|
|
|
|
1865 |
# Initialize weights and apply final processing
|
1866 |
self.post_init()
|
1867 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1868 |
def get_encoder(self):
    """
    Return the encoder by delegating to `self.model.get_encoder()`.
    """
    wrapped_model = self.model
    return wrapped_model.get_encoder()
|
1870 |
|
|
|
1064 |
return embed_pos[:max_pos_len]
|
1065 |
else:
|
1066 |
return embed_pos[-max_pos_len:]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1067 |
|
1068 |
|
1069 |
def forward(
|
|
|
1803 |
self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
1804 |
self.max_target_positions = config.max_target_positions
|
1805 |
|
1806 |
+
self.patch_generate()
|
1807 |
+
|
1808 |
# Initialize weights and apply final processing
|
1809 |
self.post_init()
|
1810 |
|
1811 |
+
# CUSTOM (Monkeypatch the generation method)
|
1812 |
+
def patch_generate(self):
    """
    Monkey patches the WhisperGenerationMixin to use dynamic stride calculation

    The stock `generate` reads the encoder stride as
    `self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]`.
    For an encoder that exposes `conv_layers` instead of `conv1`/`conv2`, the
    installed wrapper temporarily provides `conv1`/`conv2` stand-ins whose
    strides multiply to the product of all `conv_layers` strides, runs the
    unmodified `generate`, and then removes the stand-ins again.

    Encoders that already define `conv1` or `conv2`, or that have no
    `conv_layers`, are passed straight through to the original `generate`.

    The patch is global (it modifies the mixin class, not this instance) and
    idempotent: this method may be called once per model instance without
    stacking another wrapper each time.

    NOTE(review): the temporary attributes are not guarded by a lock, so
    concurrent `generate` calls on the same model from several threads are
    not protected.
    """
    # Function-scope import keeps this block self-contained.
    import functools

    # Already patched by an earlier call: wrapping again would only stack
    # wrappers on the shared mixin.
    if getattr(WhisperGenerationMixin.generate, "_conv_stride_patched", False):
        return

    original_generate = WhisperGenerationMixin.generate

    class _StrideShim:
        """Minimal stand-in exposing only the `stride` tuple that `generate` reads."""

        __slots__ = ("stride",)

        def __init__(self, stride):
            self.stride = (stride,)

    def get_conv_stride(self):
        """Calculate total stride of all conv layers"""
        total_stride = 1
        for layer in self.model.encoder.conv_layers:
            total_stride *= layer.stride[0]
        return total_stride

    @functools.wraps(original_generate)
    def generate_wrapper(self, *args, **kwargs):
        encoder = self.model.encoder

        # Nothing to adapt: the hardcoded lookup already resolves (this also
        # covers nested calls made while the shims are installed), or there
        # are no conv_layers to derive a stride from.
        if (
            hasattr(encoder, "conv1")
            or hasattr(encoder, "conv2")
            or not hasattr(encoder, "conv_layers")
        ):
            return original_generate(self, *args, **kwargs)

        # conv1 carries the full stride and conv2 a neutral 1, so the
        # hardcoded product conv1.stride[0] * conv2.stride[0] is the total.
        encoder.conv1 = _StrideShim(self.get_conv_stride())
        encoder.conv2 = _StrideShim(1)
        try:
            return original_generate(self, *args, **kwargs)
        finally:
            # Always restore the encoder, even when generate raises.
            del encoder.conv1
            del encoder.conv2

    generate_wrapper._conv_stride_patched = True

    # Add the stride calculation method to the mixin
    WhisperGenerationMixin.get_conv_stride = get_conv_stride
    # Replace the original generate method
    WhisperGenerationMixin.generate = generate_wrapper
|
1869 |
+
|
1870 |
def get_encoder(self):
    """
    Return the encoder by delegating to `self.model.get_encoder()`.
    """
    wrapped_model = self.model
    return wrapped_model.get_encoder()
|
1872 |
|