mrprimenotes commited on
Commit
81863a0
·
verified ·
1 Parent(s): 3b51d88

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +16 -31
model.py CHANGED
@@ -3,24 +3,6 @@ from typing import List, Literal, Optional
3
  import types
4
 
5
  """Custom config to support modification of the Whisper encoder."""
6
-
7
- class ConvLayerConfig:
8
- """Configuration for a single convolutional layer"""
9
- def __init__(
10
- self,
11
- in_channels: int,
12
- out_channels: int,
13
- kernel_size: int,
14
- stride: int = 1,
15
- padding: int = 0,
16
- activation: Literal["gelu", "relu", "none"] = "gelu"
17
- ):
18
- self.in_channels = in_channels
19
- self.out_channels = out_channels
20
- self.kernel_size = kernel_size
21
- self.stride = stride
22
- self.padding = padding
23
- self.activation = activation
24
 
25
  class CustomWhisperConfig(WhisperConfig):
26
  def __init__(
@@ -43,19 +25,22 @@ class CustomWhisperConfig(WhisperConfig):
43
 
44
  if conv_preprocessing_layers is None:
45
  conv_preprocessing_layers = [
46
- ConvLayerConfig(
47
- in_channels=self.num_mel_bins,
48
- out_channels=self.d_model,
49
- kernel_size=3,
50
- padding=1
51
- ),
52
- ConvLayerConfig(
53
- in_channels=self.d_model,
54
- out_channels=self.d_model,
55
- kernel_size=3,
56
- stride=2,
57
- padding=1
58
- )
 
 
 
59
  ]
60
 
61
  self.conv_preprocessing_layers = conv_preprocessing_layers
 
3
  import types
4
 
5
  """Custom config to support modification of the Whisper encoder."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  class CustomWhisperConfig(WhisperConfig):
8
  def __init__(
 
25
 
26
  if conv_preprocessing_layers is None:
27
  conv_preprocessing_layers = [
28
+ {
29
+ "in_channels": self.num_mel_bins,
30
+ "out_channels": self.d_model,
31
+ "kernel_size": 3,
32
+ "stride": 1,
33
+ "padding": 1,
34
+ "activation": "gelu"
35
+ },
36
+ {
37
+ "in_channels": self.d_model,
38
+ "out_channels": self.d_model,
39
+ "kernel_size": 3,
40
+ "stride": 2,
41
+ "padding": 1,
42
+ "activation": "gelu"
43
+ }
44
  ]
45
 
46
  self.conv_preprocessing_layers = conv_preprocessing_layers