christopher-hoernle
commited on
Commit
·
524f3ac
1
Parent(s):
e79974f
fix #4
Browse files
model.py
CHANGED
@@ -1032,10 +1032,9 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|
1032 |
|
1033 |
# CUSTOM
|
1034 |
def get_conv_stride(self):
|
1035 |
-
"""Calculate total stride of all conv layers"""
|
1036 |
total_stride = 1
|
1037 |
for layer in self.conv_layers:
|
1038 |
-
total_stride *= layer.stride[0]
|
1039 |
return total_stride
|
1040 |
|
1041 |
def _freeze_parameters(self):
|
@@ -1116,7 +1115,7 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|
1116 |
"""
|
1117 |
|
1118 |
# CUSTOM
|
1119 |
-
expected_seq_length = self.config.max_source_positions * self.get_conv_stride()
|
1120 |
|
1121 |
# CUSTOM
|
1122 |
# Must be deactivated for our purpose, theoretically Whisper supports any sequence length for the encoder
|
|
|
1032 |
|
1033 |
# CUSTOM
|
1034 |
def get_conv_stride(self):
|
|
|
1035 |
total_stride = 1
|
1036 |
for layer in self.conv_layers:
|
1037 |
+
total_stride *= layer[0].stride[0]
|
1038 |
return total_stride
|
1039 |
|
1040 |
def _freeze_parameters(self):
|
|
|
1115 |
"""
|
1116 |
|
1117 |
# CUSTOM
|
1118 |
+
#expected_seq_length = self.config.max_source_positions * self.get_conv_stride()
|
1119 |
|
1120 |
# CUSTOM
|
1121 |
# Must be deactivated for our purpose, theoretically Whisper supports any sequence length for the encoder
|