# File size: 18,887 Bytes
# Revision: f196feb
# This module is from [WeNet](https://github.com/wenet-e2e/wenet).
# ## Citations
# ```bibtex
# @inproceedings{yao2021wenet,
# title={WeNet: Production oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit},
# author={Yao, Zhuoyuan and Wu, Di and Wang, Xiong and Zhang, Binbin and Yu, Fan and Yang, Chao and Peng, Zhendong and Chen, Xiaoyu and Xie, Lei and Lei, Xin},
# booktitle={Proc. Interspeech},
# year={2021},
# address={Brno, Czech Republic },
# organization={IEEE}
# }
# @article{zhang2022wenet,
# title={WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit},
# author={Zhang, Binbin and Wu, Di and Peng, Zhendong and Song, Xingchen and Yao, Zhuoyuan and Lv, Hang and Xie, Lei and Yang, Chao and Pan, Fuping and Niu, Jianwei},
# journal={arXiv preprint arXiv:2203.15455},
# year={2022}
# }
# ```
#
from __future__ import print_function
import argparse
import os
import copy
import sys
import torch
import yaml
import numpy as np
from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model
try:
import onnx
import onnxruntime
from onnxruntime.quantization import quantize_dynamic, QuantType
except ImportError:
print("Please install onnx and onnxruntime!")
sys.exit(1)
def get_args(argv=None):
    """Parse the command-line options of the export script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what
            the script did before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``config``, ``checkpoint``,
        ``output_dir`` (strings), ``chunk_size`` and
        ``num_decoding_left_chunks`` (ints) and ``reverse_weight``
        (float, 0.5 when not given).

    Raises:
        SystemExit: raised by argparse when a required option is missing
            or a value cannot be converted.
    """
    parser = argparse.ArgumentParser(description="export your script model")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--checkpoint", required=True, help="checkpoint model")
    parser.add_argument("--output_dir", required=True, help="output directory")
    parser.add_argument(
        "--chunk_size", required=True, type=int, help="decoding chunk size"
    )
    parser.add_argument(
        "--num_decoding_left_chunks", required=True, type=int, help="cache chunks"
    )
    parser.add_argument(
        "--reverse_weight",
        default=0.5,
        type=float,
        help="reverse_weight in attention_rescoring",
    )
    # Passing argv through lets callers (and tests) drive the parser without
    # touching the process-wide sys.argv.
    return parser.parse_args(argv)
def to_numpy(tensor):
    """Return the content of ``tensor`` as a NumPy array.

    A tensor whose ``requires_grad`` flag is set is detached first; the
    result of ``.cpu().numpy()`` on the (possibly detached) tensor is
    returned.
    """
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
def print_input_output_info(onnx_model, name, prefix="\t\t"):
    """Print the names and shapes of a model graph's inputs and outputs.

    Four lines are printed (input names, input shapes, output names,
    output shapes), each starting with ``prefix`` followed by ``name``.

    Args:
        onnx_model: Object exposing ``graph.input`` and ``graph.output``;
            every entry must provide ``name`` and
            ``type.tensor_type.shape.dim``.
        name: Label printed in front of every line.
        prefix: Text printed before ``name`` (two tabs by default).
    """

    def _shape(value_info):
        # A symbolic (dynamic) axis carries its label in `dim_param` and
        # leaves `dim_value` at 0, so reading only `dim_value` would show a
        # misleading 0. Prefer the label when there is one.
        return [
            getattr(d, "dim_param", "") or d.dim_value
            for d in value_info.type.tensor_type.shape.dim
        ]

    graph_inputs = onnx_model.graph.input
    graph_outputs = onnx_model.graph.output
    input_names = [node.name for node in graph_inputs]
    input_shapes = [_shape(node) for node in graph_inputs]
    output_names = [node.name for node in graph_outputs]
    output_shapes = [_shape(node) for node in graph_outputs]
    print("{}{} inputs : {}".format(prefix, name, input_names))
    print("{}{} input shapes : {}".format(prefix, name, input_shapes))
    print("{}{} outputs: {}".format(prefix, name, output_names))
    print("{}{} output shapes : {}".format(prefix, name, output_shapes))
def export_encoder(asr_model, args):
    """Export the encoder to ONNX, write a quantized copy and verify it.

    Steps, in order:
      1. Rebind ``encoder.forward`` to ``encoder.forward_chunk`` and build
         dummy inputs (chunk, offset, required cache size, attention
         cache, CNN cache, attention mask).
      2. Trace with ``torch.onnx.export`` into
         ``<output_dir>/encoder.onnx``, attach every ``args`` entry as
         model metadata and save the file again.
      3. Write a dynamically quantized copy to
         ``<output_dir>/encoder.quant.onnx``.
      4. Feed the same dummy chunk 10 times through the PyTorch encoder
         and through an onnxruntime session of the non-quantized file,
         carrying the caches forward, and assert that the concatenated
         outputs agree.

    Args:
        asr_model: Model object exposing an ``encoder`` attribute that has
            a ``forward_chunk`` method. Its ``forward`` is overwritten.
        args: Dict of export settings. Keys read here: ``output_dir``,
            ``batch``, ``decoding_window``, ``feature_size``,
            ``left_chunks``, ``chunk_size``, ``num_blocks``, ``head``,
            ``output_size``, ``cnn_module_kernel``. All entries are also
            stored (stringified) in the ONNX metadata.

    Raises:
        AssertionError: from ``np.testing.assert_allclose`` when the ONNX
            output differs from the PyTorch output beyond
            rtol=1e-3 / atol=1e-5.
    """
    print("Stage-1: export encoder")
    encoder = asr_model.encoder
    # Calling the module now runs forward_chunk; this is what gets traced.
    encoder.forward = encoder.forward_chunk
    encoder_outpath = os.path.join(args["output_dir"], "encoder.onnx")

    print("\tStage-1.1: prepare inputs for encoder")
    chunk = torch.randn((args["batch"], args["decoding_window"], args["feature_size"]))
    offset = 0
    # NOTE(xcsong): The uncertainty of `next_cache_start` only appears
    #   in the first few chunks, this is caused by dynamic att_cache shape, i.e.
    #   (0, 0, 0, 0) for 1st chunk and (elayers, head, ?, d_k*2) for subsequent
    #   chunks. One way to ease the ONNX export is to keep `next_cache_start`
    #   as a fixed value. To do this, for the **first** chunk, if
    #   left_chunks > 0, we feed real cache & real mask to the model, otherwise
    #   fake cache & fake mask. In this way, we get:
    #   1. 16/-1 mode: next_cache_start == 0 for all chunks
    #   2. 16/4  mode: next_cache_start == chunk_size for all chunks
    #   3. 16/0  mode: next_cache_start == chunk_size for all chunks
    #   4. -1/-1 mode: next_cache_start == 0 for all chunks
    #   NO MORE DYNAMIC CHANGES!!
    #
    # NOTE(Mddct): We retain the current design for the convenience of supporting some
    #   inference frameworks without dynamic shapes. If you're interested in all-in-one
    #   model that supports different chunks please see:
    #   https://github.com/wenet-e2e/wenet/pull/1174
    if args["left_chunks"] > 0:  # 16/4
        required_cache_size = args["chunk_size"] * args["left_chunks"]
        offset = required_cache_size
        # Real cache
        att_cache = torch.zeros(
            (
                args["num_blocks"],
                args["head"],
                required_cache_size,
                args["output_size"] // args["head"] * 2,
            )
        )
        # Real mask
        att_mask = torch.ones(
            (args["batch"], 1, required_cache_size + args["chunk_size"]),
            dtype=torch.bool,
        )
        # The cache part of the mask starts out disabled; the check loops
        # below enable it chunk by chunk.
        att_mask[:, :, :required_cache_size] = 0
    elif args["left_chunks"] <= 0:  # 16/-1, -1/-1, 16/0
        required_cache_size = -1 if args["left_chunks"] < 0 else 0
        # Fake cache
        att_cache = torch.zeros(
            (
                args["num_blocks"],
                args["head"],
                0,
                args["output_size"] // args["head"] * 2,
            )
        )
        # Fake mask
        att_mask = torch.ones((0, 0, 0), dtype=torch.bool)
    cnn_cache = torch.zeros(
        (
            args["num_blocks"],
            args["batch"],
            args["output_size"],
            args["cnn_module_kernel"] - 1,
        )
    )
    inputs = (chunk, offset, required_cache_size, att_cache, cnn_cache, att_mask)
    print(
        "\t\tchunk.size(): {}\n".format(chunk.size()),
        "\t\toffset: {}\n".format(offset),
        "\t\trequired_cache: {}\n".format(required_cache_size),
        "\t\tatt_cache.size(): {}\n".format(att_cache.size()),
        "\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()),
        "\t\tatt_mask.size(): {}\n".format(att_mask.size()),
    )

    print("\tStage-1.2: torch.onnx.export")
    dynamic_axes = {
        "chunk": {1: "T"},
        "att_cache": {2: "T_CACHE"},
        "att_mask": {2: "T_ADD_T_CACHE"},
        "output": {1: "T"},
        "r_att_cache": {2: "T_CACHE"},
    }
    # NOTE(xcsong): We keep dynamic axes even if in 16/4 mode, this is
    #   to avoid padding the last chunk (which usually contains less
    #   frames than required). For users who want static axes, just pop
    #   out specific axis.
    # if args['chunk_size'] > 0:  # 16/4, 16/-1, 16/0
    #     dynamic_axes.pop('chunk')
    #     dynamic_axes.pop('output')
    # if args['left_chunks'] >= 0:  # 16/4, 16/0
    #     # NOTE(xsong): since we feed real cache & real mask into the
    #     #   model when left_chunks > 0, the shape of cache will never
    #     #   be changed.
    #     dynamic_axes.pop('att_cache')
    #     dynamic_axes.pop('r_att_cache')
    torch.onnx.export(
        encoder,
        inputs,
        encoder_outpath,
        opset_version=13,
        export_params=True,
        do_constant_folding=True,
        input_names=[
            "chunk",
            "offset",
            "required_cache_size",
            "att_cache",
            "cnn_cache",
            "att_mask",
        ],
        output_names=["output", "r_att_cache", "r_cnn_cache"],
        dynamic_axes=dynamic_axes,
        verbose=False,
    )
    onnx_encoder = onnx.load(encoder_outpath)
    # Store every export setting as a string key/value pair in the model.
    for k, v in args.items():
        meta = onnx_encoder.metadata_props.add()
        meta.key, meta.value = str(k), str(v)
    onnx.checker.check_model(onnx_encoder)
    # NOTE(review): the returned text is discarded, so this call only has an
    #   effect if it raises.
    onnx.helper.printable_graph(onnx_encoder.graph)
    # NOTE(xcsong): to add those metadatas we need to reopen
    #   the file and resave it.
    onnx.save(onnx_encoder, encoder_outpath)
    print_input_output_info(onnx_encoder, "onnx_encoder")
    # Dynamic quantization
    model_fp32 = encoder_outpath
    model_quant = os.path.join(args["output_dir"], "encoder.quant.onnx")
    quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
    print("\t\tExport onnx_encoder, done! see {}".format(encoder_outpath))

    print("\tStage-1.3: check onnx_encoder and torch_encoder")
    # Reference run with PyTorch. Copies are used so the ONNX run below
    # starts from the same untouched inputs.
    torch_output = []
    torch_chunk = copy.deepcopy(chunk)
    torch_offset = copy.deepcopy(offset)
    torch_required_cache_size = copy.deepcopy(required_cache_size)
    torch_att_cache = copy.deepcopy(att_cache)
    torch_cnn_cache = copy.deepcopy(cnn_cache)
    torch_att_mask = copy.deepcopy(att_mask)
    for i in range(10):
        print(
            "\t\ttorch chunk-{}: {}, offset: {}, att_cache: {},"
            " cnn_cache: {}, att_mask: {}".format(
                i,
                list(torch_chunk.size()),
                torch_offset,
                list(torch_att_cache.size()),
                list(torch_cnn_cache.size()),
                list(torch_att_mask.size()),
            )
        )
        # NOTE(xsong): att_mask of the first few batches need changes if
        #   we use 16/4 mode.
        if args["left_chunks"] > 0:  # 16/4
            torch_att_mask[:, :, -(args["chunk_size"] * (i + 1)) :] = 1
        out, torch_att_cache, torch_cnn_cache = encoder(
            torch_chunk,
            torch_offset,
            torch_required_cache_size,
            torch_att_cache,
            torch_cnn_cache,
            torch_att_mask,
        )
        torch_output.append(out)
        torch_offset += out.size(1)
    torch_output = torch.cat(torch_output, dim=1)

    # Same 10 steps through onnxruntime, on the non-quantized model file.
    onnx_output = []
    onnx_chunk = to_numpy(chunk)
    onnx_offset = np.array((offset)).astype(np.int64)
    onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64)
    onnx_att_cache = to_numpy(att_cache)
    onnx_cnn_cache = to_numpy(cnn_cache)
    onnx_att_mask = to_numpy(att_mask)
    ort_session = onnxruntime.InferenceSession(encoder_outpath)
    input_names = [node.name for node in onnx_encoder.graph.input]
    for i in range(10):
        print(
            "\t\tonnx  chunk-{}: {}, offset: {}, att_cache: {},"
            " cnn_cache: {}, att_mask: {}".format(
                i,
                onnx_chunk.shape,
                onnx_offset,
                onnx_att_cache.shape,
                onnx_cnn_cache.shape,
                onnx_att_mask.shape,
            )
        )
        # NOTE(xsong): att_mask of the first few batches need changes if
        #   we use 16/4 mode.
        if args["left_chunks"] > 0:  # 16/4
            onnx_att_mask[:, :, -(args["chunk_size"] * (i + 1)) :] = 1
        ort_inputs = {
            "chunk": onnx_chunk,
            "offset": onnx_offset,
            "required_cache_size": onnx_required_cache_size,
            "att_cache": onnx_att_cache,
            "cnn_cache": onnx_cnn_cache,
            "att_mask": onnx_att_mask,
        }
        # NOTE(xcsong): If we use 16/-1, -1/-1 or 16/0 mode, `next_cache_start`
        #   will be hardcoded to 0 or chunk_size by ONNX, thus
        #   required_cache_size and att_mask are no more needed and they will
        #   be removed by ONNX automatically.
        for k in list(ort_inputs):
            if k not in input_names:
                ort_inputs.pop(k)
        ort_outs = ort_session.run(None, ort_inputs)
        onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2]
        onnx_output.append(ort_outs[0])
        onnx_offset += ort_outs[0].shape[1]
    onnx_output = np.concatenate(onnx_output, axis=1)

    np.testing.assert_allclose(
        to_numpy(torch_output), onnx_output, rtol=1e-03, atol=1e-05
    )
    meta = ort_session.get_modelmeta()
    print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map))
    print("\t\tCheck onnx_encoder, pass!")
def export_ctc(asr_model, args):
    """Export the CTC module to ONNX, write a quantized copy and verify it.

    ``asr_model.ctc.forward`` is rebound to its ``log_softmax`` method, the
    module is traced with a random ``hidden`` input into
    ``<output_dir>/ctc.onnx``, every ``args`` entry is stored as model
    metadata, a dynamically quantized copy goes to
    ``<output_dir>/ctc.quant.onnx`` and the non-quantized ONNX output is
    compared with the PyTorch output.

    Args:
        asr_model: Model object exposing a ``ctc`` attribute that has a
            ``log_softmax`` method.
        args: Dict of export settings. Keys read here: ``output_dir``,
            ``batch``, ``chunk_size``, ``output_size``.

    Raises:
        AssertionError: from ``np.testing.assert_allclose`` when the two
            outputs differ beyond rtol=1e-3 / atol=1e-5.
    """
    print("Stage-2: export ctc")
    ctc_head = asr_model.ctc
    # A plain call on the module must run log_softmax while tracing.
    ctc_head.forward = ctc_head.log_softmax
    fp32_path = os.path.join(args["output_dir"], "ctc.onnx")

    print("\tStage-2.1: prepare inputs for ctc")
    batch = args["batch"]
    # A non-positive chunk size gets a fixed dummy length of 16 frames.
    if args["chunk_size"] > 0:
        num_frames = args["chunk_size"]
    else:
        num_frames = 16
    dummy_hidden = torch.randn((batch, num_frames, args["output_size"]))

    print("\tStage-2.2: torch.onnx.export")
    torch.onnx.export(
        ctc_head,
        dummy_hidden,
        fp32_path,
        opset_version=13,
        export_params=True,
        do_constant_folding=True,
        input_names=["hidden"],
        output_names=["probs"],
        dynamic_axes={"hidden": {1: "T"}, "probs": {1: "T"}},
        verbose=False,
    )

    exported = onnx.load(fp32_path)
    # Keep the export settings inside the model file as string metadata.
    for key, value in args.items():
        prop = exported.metadata_props.add()
        prop.key = str(key)
        prop.value = str(value)
    onnx.checker.check_model(exported)
    onnx.helper.printable_graph(exported.graph)
    # The metadata only reaches disk once the model is saved again.
    onnx.save(exported, fp32_path)
    print_input_output_info(exported, "onnx_ctc")

    # Dynamic quantization
    quant_path = os.path.join(args["output_dir"], "ctc.quant.onnx")
    quantize_dynamic(fp32_path, quant_path, weight_type=QuantType.QUInt8)
    print("\t\tExport onnx_ctc, done! see {}".format(fp32_path))

    print("\tStage-2.3: check onnx_ctc and torch_ctc")
    expected = ctc_head(dummy_hidden)
    session = onnxruntime.InferenceSession(fp32_path)
    actual = session.run(None, {"hidden": to_numpy(dummy_hidden)})
    np.testing.assert_allclose(
        to_numpy(expected), actual[0], rtol=1e-03, atol=1e-05
    )
    print("\t\tCheck onnx_ctc, pass!")
def export_decoder(asr_model, args):
    """Export the attention-rescoring decoder to ONNX and verify it.

    ``asr_model.forward`` is rebound to
    ``asr_model.forward_attention_decoder``, the model is traced with
    random hypotheses and encoder output into
    ``<output_dir>/decoder.onnx``, every ``args`` entry is stored as model
    metadata, a dynamically quantized copy goes to
    ``<output_dir>/decoder.quant.onnx`` and the non-quantized ONNX scores
    are compared with the PyTorch scores.

    Args:
        asr_model: Model object that has a ``forward_attention_decoder``
            method. Its ``forward`` is overwritten.
        args: Dict of export settings. Keys read here: ``output_dir``,
            ``output_size``, ``vocab_size``, ``reverse_weight``,
            ``is_bidirectional_decoder``.

    Raises:
        AssertionError: from ``np.testing.assert_allclose`` when the
            scores differ beyond rtol=1e-3 / atol=1e-5.
    """
    print("Stage-3: export decoder")
    rescorer = asr_model
    # NOTE(lzhin): parameters of encoder will be automatically removed
    #   since they are not used during rescoring.
    rescorer.forward = rescorer.forward_attention_decoder
    fp32_path = os.path.join(args["output_dir"], "decoder.onnx")

    print("\tStage-3.1: prepare inputs for decoder")
    # Dummy sizes (time=200, nbest=10, len=20); all three are dynamic axes.
    # The three random draws stay in this order so the seeded values match.
    encoder_out = torch.randn((1, 200, args["output_size"]))
    vocab_size = args["vocab_size"]
    hyps = torch.randint(low=0, high=vocab_size, size=[10, 20])
    hyps[:, 0] = vocab_size - 1  # <sos>
    hyps_lens = torch.randint(low=15, high=21, size=[10])
    reverse_weight = args["reverse_weight"]

    print("\tStage-3.2: torch.onnx.export")
    torch.onnx.export(
        rescorer,
        (hyps, hyps_lens, encoder_out, reverse_weight),
        fp32_path,
        opset_version=13,
        export_params=True,
        do_constant_folding=True,
        input_names=["hyps", "hyps_lens", "encoder_out", "reverse_weight"],
        output_names=["score", "r_score"],
        dynamic_axes={
            "hyps": {0: "NBEST", 1: "L"},
            "hyps_lens": {0: "NBEST"},
            "encoder_out": {1: "T"},
            "score": {0: "NBEST", 1: "L"},
            "r_score": {0: "NBEST", 1: "L"},
        },
        verbose=False,
    )

    exported = onnx.load(fp32_path)
    # Keep the export settings inside the model file as string metadata.
    for key, value in args.items():
        prop = exported.metadata_props.add()
        prop.key = str(key)
        prop.value = str(value)
    onnx.checker.check_model(exported)
    onnx.helper.printable_graph(exported.graph)
    onnx.save(exported, fp32_path)
    print_input_output_info(exported, "onnx_decoder")

    quant_path = os.path.join(args["output_dir"], "decoder.quant.onnx")
    quantize_dynamic(fp32_path, quant_path, weight_type=QuantType.QUInt8)
    print("\t\tExport onnx_decoder, done! see {}".format(fp32_path))

    print("\tStage-3.3: check onnx_decoder and torch_decoder")
    torch_score, torch_r_score = rescorer(
        hyps, hyps_lens, encoder_out, reverse_weight
    )
    session = onnxruntime.InferenceSession(fp32_path)
    graph_input_names = [node.name for node in exported.graph.input]
    candidates = {
        "hyps": to_numpy(hyps),
        "hyps_lens": to_numpy(hyps_lens),
        "encoder_out": to_numpy(encoder_out),
        "reverse_weight": np.array((reverse_weight)),
    }
    # Feed only the inputs that are still present in the exported graph.
    ort_inputs = {
        key: value for key, value in candidates.items() if key in graph_input_names
    }
    onnx_scores = session.run(None, ort_inputs)

    np.testing.assert_allclose(
        to_numpy(torch_score), onnx_scores[0], rtol=1e-03, atol=1e-05
    )
    # The reverse score is only checked for a bidirectional decoder with a
    # positive reverse weight.
    if args["is_bidirectional_decoder"] and reverse_weight > 0.0:
        np.testing.assert_allclose(
            to_numpy(torch_r_score), onnx_scores[1], rtol=1e-03, atol=1e-05
        )
    print("\t\tCheck onnx_decoder, pass!")
def main():
    """Load a model from config + checkpoint and export it to ONNX.

    Reads the command line through ``get_args``, creates the output
    directory, hides CUDA devices, builds the model with ``init_model``
    and ``load_checkpoint``, collects the export settings into one dict
    and runs ``export_encoder``, ``export_ctc`` and ``export_decoder``
    with it.

    Raises:
        ValueError: if ``--num_decoding_left_chunks`` is positive while
            ``--chunk_size`` is not (the -1/4 mode is not supported).
    """
    torch.manual_seed(777)
    args = get_args()
    # Checked up front with a real exception: an `assert` disappears under
    # `python -O`, and failing early avoids loading the checkpoint for an
    # unsupported mode.
    if args.num_decoding_left_chunks > 0 and args.chunk_size <= 0:
        raise ValueError(
            "num_decoding_left_chunks > 0 requires chunk_size > 0 "
            "(-1/4 not supported)"
        )

    output_dir = args.output_dir
    # os.makedirs instead of a shell `mkdir -p` string: works on every
    # platform, copes with spaces and shell metacharacters in the path, and
    # raises on failure instead of silently continuing.
    os.makedirs(output_dir, exist_ok=True)
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    with open(args.config, "r") as fin:
        # NOTE(review): FullLoader can construct more than plain data; the
        #   config file is assumed to be trusted. Consider yaml.safe_load
        #   if configs may come from untrusted sources.
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    model = init_model(configs)
    load_checkpoint(model, args.checkpoint)
    model.eval()
    print(model)

    # Insertion order matters: the export functions write these entries as
    # ONNX metadata in this order.
    arguments = {}
    arguments["output_dir"] = output_dir
    arguments["batch"] = 1
    arguments["chunk_size"] = args.chunk_size
    arguments["left_chunks"] = args.num_decoding_left_chunks
    arguments["reverse_weight"] = args.reverse_weight
    arguments["output_size"] = configs["encoder_conf"]["output_size"]
    arguments["num_blocks"] = configs["encoder_conf"]["num_blocks"]
    arguments["cnn_module_kernel"] = configs["encoder_conf"].get("cnn_module_kernel", 1)
    arguments["head"] = configs["encoder_conf"]["attention_heads"]
    arguments["feature_size"] = configs["input_dim"]
    arguments["vocab_size"] = configs["output_dim"]
    # NOTE(xcsong): if chunk_size == -1, hardcode to 67
    arguments["decoding_window"] = (
        (args.chunk_size - 1) * model.encoder.embed.subsampling_rate
        + model.encoder.embed.right_context
        + 1
        if args.chunk_size > 0
        else 67
    )
    arguments["encoder"] = configs["encoder"]
    arguments["decoder"] = configs["decoder"]
    arguments["subsampling_rate"] = model.subsampling_rate()
    arguments["right_context"] = model.right_context()
    arguments["sos_symbol"] = model.sos_symbol()
    arguments["eos_symbol"] = model.eos_symbol()
    arguments["is_bidirectional_decoder"] = 1 if model.is_bidirectional_decoder() else 0

    # NOTE(xcsong): Please note that -1/-1 means non-streaming model! It is
    #   not a [16/4 16/-1 16/0] all-in-one model and it should not be used in
    #   streaming mode (i.e., setting chunk_size=16 in `decoder_main`). If you
    #   want to use 16/-1 or any other streaming mode in `decoder_main`,
    #   please export onnx in the same config.
    export_encoder(model, arguments)
    export_ctc(model, arguments)
    export_decoder(model, arguments)
# Run the export only when this file is executed as a script.
if __name__ == "__main__":
    main()
|