Update model.py
Browse files
model.py
CHANGED
@@ -31,7 +31,8 @@ class CustomWhisperConfig(WhisperConfig):
|
|
31 |
"kernel_size": 3,
|
32 |
"stride": 1,
|
33 |
"padding": 1,
|
34 |
-
"activation": "gelu"
|
|
|
35 |
},
|
36 |
{
|
37 |
"in_channels": self.d_model,
|
@@ -39,7 +40,8 @@ class CustomWhisperConfig(WhisperConfig):
|
|
39 |
"kernel_size": 3,
|
40 |
"stride": 2,
|
41 |
"padding": 1,
|
42 |
-
"activation": "gelu"
|
|
|
43 |
}
|
44 |
]
|
45 |
|
@@ -996,16 +998,21 @@ class WhisperEncoder(WhisperPreTrainedModel):
|
|
996 |
# CUSTOM
|
997 |
# Create conv layers dynamically based on config
|
998 |
self.conv_layers = nn.ModuleList()
|
|
|
999 |
for layer_config in config.conv_preprocessing_layers:
|
1000 |
-
|
1001 |
-
|
1002 |
-
|
1003 |
-
|
1004 |
-
|
1005 |
-
|
1006 |
-
|
1007 |
-
|
1008 |
-
|
|
|
|
|
|
|
|
|
1009 |
|
1010 |
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
1011 |
self.embed_positions.requires_grad_(False)
|
|
|
31 |
"kernel_size": 3,
|
32 |
"stride": 1,
|
33 |
"padding": 1,
|
34 |
+
"activation": "gelu",
|
35 |
+
"bias": True
|
36 |
},
|
37 |
{
|
38 |
"in_channels": self.d_model,
|
|
|
40 |
"kernel_size": 3,
|
41 |
"stride": 2,
|
42 |
"padding": 1,
|
43 |
+
"activation": "gelu",
|
44 |
+
"bias": True
|
45 |
}
|
46 |
]
|
47 |
|
|
|
998 |
# CUSTOM
|
999 |
# Create conv layers dynamically based on config
|
1000 |
self.conv_layers = nn.ModuleList()
|
1002 |
for layer_config in config.conv_preprocessing_layers:
|
1003 |
+
# Create sequential module for each conv+activation pair
|
1004 |
+
conv_sequence = nn.Sequential(
|
1005 |
+
nn.Conv1d(
|
1006 |
+
layer_config["in_channels"],
|
1007 |
+
layer_config["out_channels"],
|
1008 |
+
kernel_size=layer_config["kernel_size"],
|
1009 |
+
stride=layer_config["stride"],
|
1010 |
+
padding=layer_config["padding"],
|
1011 |
+
bias=layer_config.get("bias", True)
|
1012 |
+
),
|
1013 |
+
nn.GELU() if layer_config["activation"] == "gelu" else nn.ReLU()
|
1014 |
+
)
|
1015 |
+
self.conv_layers.append(conv_sequence)
|
1016 |
|
1017 |
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
1018 |
self.embed_positions.requires_grad_(False)
|