mrprimenotes committed on
Commit
793fd0d
·
verified ·
1 Parent(s): 5993784

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +10 -2
model.py CHANGED
@@ -998,7 +998,6 @@ class WhisperEncoder(WhisperPreTrainedModel):
998
  # CUSTOM
999
  # Create conv layers dynamically based on config
1000
  self.conv_layers = nn.ModuleList()
1001
- self.conv_layers = nn.ModuleList()
1002
  for layer_config in config.conv_preprocessing_layers:
1003
  # Create sequential module for each conv+activation pair
1004
  conv_sequence = nn.Sequential(
@@ -1024,6 +1023,14 @@ class WhisperEncoder(WhisperPreTrainedModel):
1024
  # Initialize weights and apply final processing
1025
  self.post_init()
1026
 
 
 
 
 
 
 
 
 
1027
  def _freeze_parameters(self):
1028
  for param in self.parameters():
1029
  param.requires_grad = False
@@ -1101,7 +1108,8 @@ class WhisperEncoder(WhisperPreTrainedModel):
1101
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1102
  """
1103
 
1104
- expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
 
1105
 
1106
  # CUSTOM
1107
  # Must be deactivated for our purpose, theoretically Whisper supports any sequence length for the encoder
 
998
  # CUSTOM
999
  # Create conv layers dynamically based on config
1000
  self.conv_layers = nn.ModuleList()
 
1001
  for layer_config in config.conv_preprocessing_layers:
1002
  # Create sequential module for each conv+activation pair
1003
  conv_sequence = nn.Sequential(
 
1023
  # Initialize weights and apply final processing
1024
  self.post_init()
1025
 
1026
+ # CUSTOM
1027
+ def get_conv_stride(self):
1028
+ """Calculate total stride of all conv layers"""
1029
+ total_stride = 1
1030
+ for layer in self.conv_layers:
1031
+ total_stride *= layer.stride[0]
1032
+ return total_stride
1033
+
1034
  def _freeze_parameters(self):
1035
  for param in self.parameters():
1036
  param.requires_grad = False
 
1108
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1109
  """
1110
 
1111
+ # CUSTOM
1112
+ expected_seq_length = self.config.max_source_positions * self.get_conv_stride()
1113
 
1114
  # CUSTOM
1115
  # Must be deactivated for our purpose, theoretically Whisper supports any sequence length for the encoder