brdhaker3 committed on
Commit
0331840
·
verified ·
1 Parent(s): 1da6f48

Upload train.yaml

Browse files
Files changed (1) hide show
  1. train.yaml +174 -0
train.yaml ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ################################
2
+ # Model: wav2vec2 + DNN + CTC
3
+ # Augmentation: SpecAugment
4
+ # Authors: Titouan Parcollet 2021
5
+ # ################################
6
+
7
+ # Seed needs to be set at top of yaml, before objects with parameters are made
8
+ seed: 1234
9
+ __set_seed: !!python/object/apply:torch.manual_seed [!ref <seed>]
10
+ output_folder: !ref model/<seed>
11
+ wer_file: !ref <output_folder>/wer.txt
12
+ save_folder: !ref <output_folder>/save
13
+ train_log: !ref <output_folder>/train_log.txt
14
+
15
+ # Local folder where the pretrained wav2vec2 model checkpoint is stored.
16
+ wav2vec2_folder: !ref <save_folder>/wav2vec2_checkpoint
17
+
18
+ # Data files
19
+ data_folder: /path/to/data # e.g., /localscratch/cv-corpus-5.1-2020-06-22/fr
20
+ train_tsv_file: !ref <data_folder>/train.tsv # Standard CommonVoice .tsv files
21
+ dev_tsv_file: !ref <data_folder>/dev.tsv # Standard CommonVoice .tsv files
22
+ test_tsv_file: !ref <data_folder>/test.tsv # Standard CommonVoice .tsv files
23
+ accented_letters: True
24
+ language: fr # use 'it' for Italian, 'rw' for Kinyarwanda, 'en' for English
25
+ train_csv: Data/train_wavs/train.csv
26
+ valid_csv: Data/dev_wavs/dev.csv
27
+ test_csv:
28
+ - Data/test_wavs/test.csv
29
+
30
+ skip_prep: True # Skip data preparation
31
+
32
+ use_language_modelling: True
33
+ ngram_lm_path: languageModel.arpa
34
+
35
+ # We remove utterances longer than 10s in the train/dev/test sets as
36
+ # longer sentences certainly correspond to "open microphones".
37
+ avoid_if_longer_than: 10.0
38
+ avoid_if_shorter_than: 1.2
39
+
40
+
41
+ # Training parameters
42
+ number_of_epochs: 12
43
+ lr: 1.0
44
+ lr_wav2vec: 0.0001
45
+ sorting: ascending
46
+ auto_mix_prec: False
47
+ sample_rate: 16000
48
+ ckpt_interval_minutes: 30 # save checkpoint every N min
49
+
50
+ # With data_parallel batch_size is split into N jobs
51
+ # With DDP batch_size is multiplied by N jobs
52
+ # NOTE: 6 per GPU is known to fit 16GB of VRAM; larger values need more memory (verify on your GPU)
53
+ batch_size: 10
54
+ test_batch_size: 4
55
+
56
+ dataloader_options:
57
+ batch_size: !ref <batch_size>
58
+ num_workers: 6
59
+ test_dataloader_options:
60
+ batch_size: !ref <test_batch_size>
61
+ num_workers: 6
62
+
63
+ # BPE parameters
64
+ token_type: char # ["unigram", "bpe", "char"]
65
+ character_coverage: 1.0
66
+
67
+ # Model parameters
68
+ # activation: !name:torch.nn.LeakyReLU
69
+ wav2vec_output_dim: 1024
70
+ dnn_neurons: 1024
71
+ freeze_wav2vec: False
72
+ freeze_feature_extractor: True
73
+ dropout: 0.15
74
+ warmup_steps: 500 # The wav2vec2 model isn't updated for this number of steps
75
+
76
+ # Outputs
77
+ output_neurons: 40 # BPE size, index(blank/eos/bos) = 0
78
+
79
+ # Decoding parameters
80
+ # Make sure that the bos and eos indices match the BPE ones
81
+ blank_index: 0
82
+ unk_index: 1
83
+
84
+ #
85
+ # Functions and classes
86
+ #
87
+ epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
88
+ limit: !ref <number_of_epochs>
89
+
90
+
91
+ enc: !new:speechbrain.nnet.containers.Sequential
92
+ input_shape: [null, null, !ref <wav2vec_output_dim>]
93
+ linear1: !name:speechbrain.nnet.linear.Linear
94
+ n_neurons: !ref <dnn_neurons>
95
+ bias: True
96
+ bn1: !name:speechbrain.nnet.normalization.BatchNorm1d
97
+ activation: !new:torch.nn.LeakyReLU
98
+ drop: !new:torch.nn.Dropout
99
+ p: !ref <dropout>
100
+ linear2: !name:speechbrain.nnet.linear.Linear
101
+ n_neurons: !ref <dnn_neurons>
102
+ bias: True
103
+ bn2: !name:speechbrain.nnet.normalization.BatchNorm1d
104
+ activation2: !new:torch.nn.LeakyReLU
105
+ drop2: !new:torch.nn.Dropout
106
+ p: !ref <dropout>
107
+ linear3: !name:speechbrain.nnet.linear.Linear
108
+ n_neurons: !ref <dnn_neurons>
109
+ bias: True
110
+ bn3: !name:speechbrain.nnet.normalization.BatchNorm1d
111
+ activation3: !new:torch.nn.LeakyReLU
112
+
113
+ wav2vec2: !new:speechbrain.lobes.models.huggingface_transformers.wav2vec2.Wav2Vec2
114
+ source: wavlm-large/
115
+ output_norm: False
116
+ freeze: !ref <freeze_wav2vec>
117
+ freeze_feature_extractor: !ref <freeze_feature_extractor>
118
+ save_path: !ref <wav2vec2_folder>
119
+
120
+
121
+ ctc_lin: !new:speechbrain.nnet.linear.Linear
122
+ input_size: !ref <dnn_neurons>
123
+ n_neurons: !ref <output_neurons>
124
+
125
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
126
+ apply_log: True
127
+
128
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
129
+ blank_index: !ref <blank_index>
130
+
131
+ modules:
132
+ wav2vec2: !ref <wav2vec2>
133
+ enc: !ref <enc>
134
+ ctc_lin: !ref <ctc_lin>
135
+
136
+ model: !new:torch.nn.ModuleList
137
+ - [!ref <enc>, !ref <ctc_lin>]
138
+
139
+ model_opt_class: !name:torch.optim.Adadelta
140
+ lr: !ref <lr>
141
+ rho: 0.95
142
+ eps: 1.e-8
143
+
144
+ wav2vec_opt_class: !name:torch.optim.Adam
145
+ lr: !ref <lr_wav2vec>
146
+
147
+ lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
148
+ initial_value: !ref <lr>
149
+ improvement_threshold: 0.0025
150
+ annealing_factor: 0.8
151
+ patient: 0
152
+
153
+ lr_annealing_wav2vec: !new:speechbrain.nnet.schedulers.NewBobScheduler
154
+ initial_value: !ref <lr_wav2vec>
155
+ improvement_threshold: 0.0025
156
+ annealing_factor: 0.9
157
+ patient: 0
158
+
159
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
160
+ checkpoints_dir: !ref <save_folder>
161
+ recoverables:
162
+ wav2vec2: !ref <wav2vec2>
163
+ model: !ref <model>
164
+ scheduler_model: !ref <lr_annealing_model>
165
+ scheduler_wav2vec: !ref <lr_annealing_wav2vec>
166
+ counter: !ref <epoch_counter>
167
+
168
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
169
+ save_file: !ref <train_log>
170
+
171
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
172
+
173
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
174
+ split_tokens: True