Update model.py
Browse files
model.py
CHANGED
@@ -3,24 +3,6 @@ from typing import List, Literal, Optional
|
|
3 |
import types
|
4 |
|
5 |
"""Custom config to support modification of the Whisper encoder."""
|
6 |
-
|
7 |
-
class ConvLayerConfig:
|
8 |
-
"""Configuration for a single convolutional layer"""
|
9 |
-
def __init__(
|
10 |
-
self,
|
11 |
-
in_channels: int,
|
12 |
-
out_channels: int,
|
13 |
-
kernel_size: int,
|
14 |
-
stride: int = 1,
|
15 |
-
padding: int = 0,
|
16 |
-
activation: Literal["gelu", "relu", "none"] = "gelu"
|
17 |
-
):
|
18 |
-
self.in_channels = in_channels
|
19 |
-
self.out_channels = out_channels
|
20 |
-
self.kernel_size = kernel_size
|
21 |
-
self.stride = stride
|
22 |
-
self.padding = padding
|
23 |
-
self.activation = activation
|
24 |
|
25 |
class CustomWhisperConfig(WhisperConfig):
|
26 |
def __init__(
|
@@ -43,19 +25,22 @@ class CustomWhisperConfig(WhisperConfig):
|
|
43 |
|
44 |
if conv_preprocessing_layers is None:
|
45 |
conv_preprocessing_layers = [
|
46 |
-
|
47 |
-
in_channels
|
48 |
-
out_channels
|
49 |
-
kernel_size
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
59 |
]
|
60 |
|
61 |
self.conv_preprocessing_layers = conv_preprocessing_layers
|
|
|
3 |
import types
|
4 |
|
5 |
"""Custom config to support modification of the Whisper encoder."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
class CustomWhisperConfig(WhisperConfig):
|
8 |
def __init__(
|
|
|
25 |
|
26 |
if conv_preprocessing_layers is None:
|
27 |
conv_preprocessing_layers = [
|
28 |
+
{
|
29 |
+
"in_channels": self.num_mel_bins,
|
30 |
+
"out_channels": self.d_model,
|
31 |
+
"kernel_size": 3,
|
32 |
+
"stride": 1,
|
33 |
+
"padding": 1,
|
34 |
+
"activation": "gelu"
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"in_channels": self.d_model,
|
38 |
+
"out_channels": self.d_model,
|
39 |
+
"kernel_size": 3,
|
40 |
+
"stride": 2,
|
41 |
+
"padding": 1,
|
42 |
+
"activation": "gelu"
|
43 |
+
}
|
44 |
]
|
45 |
|
46 |
self.conv_preprocessing_layers = conv_preprocessing_layers
|