Update model.py
Browse files
model.py
CHANGED
@@ -998,7 +998,6 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|
998 |
# CUSTOM
|
999 |
# Create conv layers dynamically based on config
|
1000 |
self.conv_layers = nn.ModuleList()
|
1001 |
-
self.conv_layers = nn.ModuleList()
|
1002 |
for layer_config in config.conv_preprocessing_layers:
|
1003 |
# Create sequential module for each conv+activation pair
|
1004 |
conv_sequence = nn.Sequential(
|
@@ -1024,6 +1023,14 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|
1024 |
# Initialize weights and apply final processing
|
1025 |
self.post_init()
|
1026 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1027 |
def _freeze_parameters(self):
|
1028 |
for param in self.parameters():
|
1029 |
param.requires_grad = False
|
@@ -1101,7 +1108,8 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|
1101 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1102 |
"""
|
1103 |
|
1104 |
-
|
|
|
1105 |
|
1106 |
# CUSTOM
|
1107 |
# Must be deactivated for our purpose, theoretically Whisper supports any sequence length for the encoder
|
|
|
998 |
# CUSTOM
|
999 |
# Create conv layers dynamically based on config
|
1000 |
self.conv_layers = nn.ModuleList()
|
|
|
1001 |
for layer_config in config.conv_preprocessing_layers:
|
1002 |
# Create sequential module for each conv+activation pair
|
1003 |
conv_sequence = nn.Sequential(
|
|
|
1023 |
# Initialize weights and apply final processing
|
1024 |
self.post_init()
|
1025 |
|
1026 |
+
# CUSTOM
|
1027 |
+
def get_conv_stride(self):
|
1028 |
+
"""Calculate total stride of all conv layers"""
|
1029 |
+
total_stride = 1
|
1030 |
+
for layer in self.conv_layers:
|
1031 |
+
total_stride *= layer.stride[0]
|
1032 |
+
return total_stride
|
1033 |
+
|
1034 |
def _freeze_parameters(self):
|
1035 |
for param in self.parameters():
|
1036 |
param.requires_grad = False
|
|
|
1108 |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
1109 |
"""
|
1110 |
|
1111 |
+
# CUSTOM
|
1112 |
+
expected_seq_length = self.config.max_source_positions * self.get_conv_stride()
|
1113 |
|
1114 |
# CUSTOM
|
1115 |
# Must be deactivated for our purpose, theoretically Whisper supports any sequence length for the encoder
|