pfluo csukuangfj commited on
Commit
138f2ef
·
1 Parent(s): 2568b7f

Upload torch.jit.trace() exported files (#1)

Browse files

- Upload torch.jit.trace() exported files (af38164286f600a2da1bfb74da280525869f785d)


Co-authored-by: fangjun <[email protected]>

exp/decoder_jit_trace.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e3a51b423148a03481155e7785dba05d3eda920749fd940a81a2decde30a510
3
+ size 12830070
exp/encoder_jit_trace.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95344ad3309c566bde4f0582f29a88037ba6b2ca518584f84011f281d7677def
3
+ size 330440074
exp/jit_trace_export-zh.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ """
4
+ Usage:
5
+ ./pruned_transducer_stateless7_streaming/jit_trace_export-zh.py \
6
+ --exp-dir $dir/exp \
7
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
8
+ --lang-dir ./data/lang_char_bpe \
9
+ --epoch 99 \
10
+ --avg 1 \
11
+ --use-averaged-model 0 \
12
+ \
13
+ --decode-chunk-len 32 \
14
+ --num-encoder-layers "2,4,3,2,4" \
15
+ --feedforward-dims "1024,1024,1536,1536,1024" \
16
+ --nhead "8,8,8,8,8" \
17
+ --encoder-dims "384,384,384,384,384" \
18
+ --attention-dims "192,192,192,192,192" \
19
+ --encoder-unmasked-dims "256,256,256,256,256" \
20
+ --zipformer-downsampling-factors "1,2,4,8,2" \
21
+ --cnn-module-kernels "31,31,31,31,31" \
22
+ --decoder-dim 512 \
23
+ --joiner-dim 512
24
+ """
25
+
26
+ import argparse
27
+ import logging
28
+ from pathlib import Path
29
+
30
+ import sentencepiece as spm
31
+ import torch
32
+ from scaling_converter import convert_scaled_to_non_scaled
33
+ from train import add_model_arguments, get_params, get_transducer_model
34
+ from icefall.lexicon import Lexicon
35
+
36
+ from icefall.checkpoint import (
37
+ average_checkpoints,
38
+ average_checkpoints_with_averaged_model,
39
+ find_checkpoints,
40
+ load_checkpoint,
41
+ )
42
+ from icefall.utils import AttributeDict, str2bool
43
+
44
+
45
+ def get_parser():
46
+ parser = argparse.ArgumentParser(
47
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
48
+ )
49
+
50
+ parser.add_argument(
51
+ "--epoch",
52
+ type=int,
53
+ default=28,
54
+ help="""It specifies the checkpoint to use for averaging.
55
+ Note: Epoch counts from 0.
56
+ You can specify --avg to use more checkpoints for model averaging.""",
57
+ )
58
+
59
+ parser.add_argument(
60
+ "--iter",
61
+ type=int,
62
+ default=0,
63
+ help="""If positive, --epoch is ignored and it
64
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
65
+ You can specify --avg to use more checkpoints for model averaging.
66
+ """,
67
+ )
68
+
69
+ parser.add_argument(
70
+ "--avg",
71
+ type=int,
72
+ default=15,
73
+ help="Number of checkpoints to average. Automatically select "
74
+ "consecutive checkpoints before the checkpoint specified by "
75
+ "'--epoch' and '--iter'",
76
+ )
77
+
78
+ parser.add_argument(
79
+ "--exp-dir",
80
+ type=str,
81
+ default="pruned_transducer_stateless2/exp",
82
+ help="""It specifies the directory where all training related
83
+ files, e.g., checkpoints, log, etc, are saved
84
+ """,
85
+ )
86
+
87
+ parser.add_argument(
88
+ "--lang-dir",
89
+ type=str,
90
+ default="data/lang_char",
91
+ help="The lang dir",
92
+ )
93
+
94
+ parser.add_argument(
95
+ "--context-size",
96
+ type=int,
97
+ default=2,
98
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
99
+ )
100
+
101
+ parser.add_argument(
102
+ "--use-averaged-model",
103
+ type=str2bool,
104
+ default=True,
105
+ help="Whether to load averaged model. Currently it only supports "
106
+ "using --epoch. If True, it would decode with the averaged model "
107
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
108
+ "Actually only the models with epoch number of `epoch-avg` and "
109
+ "`epoch` are loaded for averaging. ",
110
+ )
111
+
112
+ add_model_arguments(parser)
113
+
114
+ return parser
115
+
116
+
117
+ def export_encoder_model_jit_trace(
118
+ encoder_model: torch.nn.Module,
119
+ encoder_filename: str,
120
+ params: AttributeDict,
121
+ ) -> None:
122
+ """Export the given encoder model with torch.jit.trace()
123
+
124
+ Note: The warmup argument is fixed to 1.
125
+
126
+ Args:
127
+ encoder_model:
128
+ The input encoder model
129
+ encoder_filename:
130
+ The filename to save the exported model.
131
+ """
132
+ decode_chunk_len = params.decode_chunk_len # before subsampling
133
+ pad_length = 7
134
+ s = f"decode_chunk_len: {decode_chunk_len}"
135
+ logging.info(s)
136
+ assert encoder_model.decode_chunk_size == decode_chunk_len // 2, (
137
+ encoder_model.decode_chunk_size,
138
+ decode_chunk_len,
139
+ )
140
+
141
+ T = decode_chunk_len + pad_length
142
+
143
+ x = torch.zeros(1, T, 80, dtype=torch.float32)
144
+ x_lens = torch.full((1,), T, dtype=torch.int32)
145
+ states = encoder_model.get_init_state(device=x.device)
146
+
147
+ encoder_model.__class__.forward = encoder_model.__class__.streaming_forward
148
+ traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
149
+ traced_model.save(encoder_filename)
150
+ logging.info(f"Saved to {encoder_filename}")
151
+
152
+
153
+ def export_decoder_model_jit_trace(
154
+ decoder_model: torch.nn.Module,
155
+ decoder_filename: str,
156
+ ) -> None:
157
+ """Export the given decoder model with torch.jit.trace()
158
+
159
+ Note: The argument need_pad is fixed to False.
160
+
161
+ Args:
162
+ decoder_model:
163
+ The input decoder model
164
+ decoder_filename:
165
+ The filename to save the exported model.
166
+ """
167
+ y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
168
+ need_pad = torch.tensor([False])
169
+
170
+ traced_model = torch.jit.trace(decoder_model, (y, need_pad))
171
+ traced_model.save(decoder_filename)
172
+ logging.info(f"Saved to {decoder_filename}")
173
+
174
+
175
+ def export_joiner_model_jit_trace(
176
+ joiner_model: torch.nn.Module,
177
+ joiner_filename: str,
178
+ ) -> None:
179
+ """Export the given joiner model with torch.jit.trace()
180
+
181
+ Note: The argument project_input is fixed to True. A user should not
182
+ project the encoder_out/decoder_out by himself/herself. The exported joiner
183
+ will do that for the user.
184
+
185
+ Args:
186
+ joiner_model:
187
+ The input joiner model
188
+ joiner_filename:
189
+ The filename to save the exported model.
190
+
191
+ """
192
+ encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
193
+ decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
194
+ encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
195
+ decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
196
+
197
+ traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
198
+ traced_model.save(joiner_filename)
199
+ logging.info(f"Saved to {joiner_filename}")
200
+
201
+
202
+ @torch.no_grad()
203
+ def main():
204
+ args = get_parser().parse_args()
205
+ args.exp_dir = Path(args.exp_dir)
206
+
207
+ params = get_params()
208
+ params.update(vars(args))
209
+
210
+ device = torch.device("cpu")
211
+
212
+ logging.info(f"device: {device}")
213
+
214
+ lexicon = Lexicon(params.lang_dir)
215
+ params.blank_id = 0
216
+ params.vocab_size = max(lexicon.tokens) + 1
217
+
218
+ logging.info(params)
219
+
220
+ logging.info("About to create model")
221
+ model = get_transducer_model(params)
222
+
223
+ if not params.use_averaged_model:
224
+ if params.iter > 0:
225
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
226
+ : params.avg
227
+ ]
228
+ if len(filenames) == 0:
229
+ raise ValueError(
230
+ f"No checkpoints found for"
231
+ f" --iter {params.iter}, --avg {params.avg}"
232
+ )
233
+ elif len(filenames) < params.avg:
234
+ raise ValueError(
235
+ f"Not enough checkpoints ({len(filenames)}) found for"
236
+ f" --iter {params.iter}, --avg {params.avg}"
237
+ )
238
+ logging.info(f"averaging {filenames}")
239
+ model.to(device)
240
+ model.load_state_dict(average_checkpoints(filenames, device=device))
241
+ elif params.avg == 1:
242
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
243
+ else:
244
+ start = params.epoch - params.avg + 1
245
+ filenames = []
246
+ for i in range(start, params.epoch + 1):
247
+ if i >= 1:
248
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
249
+ logging.info(f"averaging {filenames}")
250
+ model.to(device)
251
+ model.load_state_dict(average_checkpoints(filenames, device=device))
252
+ else:
253
+ if params.iter > 0:
254
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
255
+ : params.avg + 1
256
+ ]
257
+ if len(filenames) == 0:
258
+ raise ValueError(
259
+ f"No checkpoints found for"
260
+ f" --iter {params.iter}, --avg {params.avg}"
261
+ )
262
+ elif len(filenames) < params.avg + 1:
263
+ raise ValueError(
264
+ f"Not enough checkpoints ({len(filenames)}) found for"
265
+ f" --iter {params.iter}, --avg {params.avg}"
266
+ )
267
+ filename_start = filenames[-1]
268
+ filename_end = filenames[0]
269
+ logging.info(
270
+ "Calculating the averaged model over iteration checkpoints"
271
+ f" from {filename_start} (excluded) to {filename_end}"
272
+ )
273
+ model.to(device)
274
+ model.load_state_dict(
275
+ average_checkpoints_with_averaged_model(
276
+ filename_start=filename_start,
277
+ filename_end=filename_end,
278
+ device=device,
279
+ )
280
+ )
281
+ else:
282
+ assert params.avg > 0, params.avg
283
+ start = params.epoch - params.avg
284
+ assert start >= 1, start
285
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
286
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
287
+ logging.info(
288
+ f"Calculating the averaged model over epoch range from "
289
+ f"{start} (excluded) to {params.epoch}"
290
+ )
291
+ model.to(device)
292
+ model.load_state_dict(
293
+ average_checkpoints_with_averaged_model(
294
+ filename_start=filename_start,
295
+ filename_end=filename_end,
296
+ device=device,
297
+ )
298
+ )
299
+
300
+ model.to("cpu")
301
+ model.eval()
302
+
303
+ convert_scaled_to_non_scaled(model, inplace=True)
304
+ logging.info("Using torch.jit.trace()")
305
+
306
+ logging.info("Exporting encoder")
307
+ encoder_filename = params.exp_dir / "encoder_jit_trace.pt"
308
+ export_encoder_model_jit_trace(model.encoder, encoder_filename, params)
309
+
310
+ logging.info("Exporting decoder")
311
+ decoder_filename = params.exp_dir / "decoder_jit_trace.pt"
312
+ export_decoder_model_jit_trace(model.decoder, decoder_filename)
313
+
314
+ logging.info("Exporting joiner")
315
+ joiner_filename = params.exp_dir / "joiner_jit_trace.pt"
316
+ export_joiner_model_jit_trace(model.joiner, joiner_filename)
317
+
318
+
319
+ if __name__ == "__main__":
320
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
321
+
322
+ logging.basicConfig(format=formatter, level=logging.INFO)
323
+ main()
exp/jit_trace_export-zh.sh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # Please go to
4
+ # https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05
5
+ # to download the pre-trained models
6
+
7
+ #
8
+ # cd $dir
9
+ # ln -s icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05/exp/pretrained-epoch-30-avg-10-averaged.pt epoch-30.pt
10
+
11
+ . path.sh
12
+
13
+ export CUDA_VISIBLE_DEVICES=""
14
+ set -ex
15
+
16
+ dir=./k2fsa-zipformer-chinese-english-mixed
17
+ if [ ! -f $dir/exp/epoch-99.pt ]; then
18
+ pushd $dir/exp
19
+ ln -s pretrained.pt epoch-99.pt
20
+ popd
21
+ fi
22
+
23
+ ./pruned_transducer_stateless7_streaming/jit_trace_export-zh.py \
24
+ --exp-dir $dir/exp \
25
+ --lang-dir $dir/data/lang_char_bpe \
26
+ --epoch 99 \
27
+ --avg 1 \
28
+ --use-averaged-model 0 \
29
+ \
30
+ --decode-chunk-len 32 \
31
+ --num-encoder-layers "2,4,3,2,4" \
32
+ --feedforward-dims "1024,1024,1536,1536,1024" \
33
+ --nhead "8,8,8,8,8" \
34
+ --encoder-dims "384,384,384,384,384" \
35
+ --attention-dims "192,192,192,192,192" \
36
+ --encoder-unmasked-dims "256,256,256,256,256" \
37
+ --zipformer-downsampling-factors "1,2,4,8,2" \
38
+ --cnn-module-kernels "31,31,31,31,31" \
39
+ --decoder-dim 512 \
40
+ --joiner-dim 512
41
+
42
+ exit 0
43
+
exp/joiner_jit_trace.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7628036f64e4c0281d02ccb696df74376297645e0400265620bf3806c31e5621
3
+ size 14679599