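"""Export a trained WeNet ASR model to ONNX for CPU deployment.

Three graphs are written to --output_dir, each with a dynamically
quantized (QUInt8) companion: encoder.onnx (streaming chunk forward),
ctc.onnx (log-softmax head) and decoder.onnx (attention rescoring).
Each exported graph is numerically checked against the PyTorch model
with onnxruntime.

Example invocation (file names are illustrative):

    python export_onnx_cpu.py \
        --config train.yaml \
        --checkpoint final.pt \
        --output_dir onnx_model \
        --chunk_size 16 \
        --num_decoding_left_chunks 4 \
        --reverse_weight 0.5
"""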
from __future__ import print_function

import argparse
import copy
import os
import sys

import numpy as np
import torch
import yaml

from wenet.utils.checkpoint import load_checkpoint
from wenet.utils.init_model import init_model

try:
    import onnx
    import onnxruntime
    from onnxruntime.quantization import quantize_dynamic, QuantType
except ImportError:
    print("Please install onnx and onnxruntime!")
    sys.exit(1)
|
|
def get_args():
    parser = argparse.ArgumentParser(description="export the trained model to ONNX")
    parser.add_argument("--config", required=True, help="config file")
    parser.add_argument("--checkpoint", required=True, help="checkpoint model")
    parser.add_argument("--output_dir", required=True, help="output directory")
    parser.add_argument(
        "--chunk_size", required=True, type=int, help="decoding chunk size"
    )
    parser.add_argument(
        "--num_decoding_left_chunks",
        required=True,
        type=int,
        help="number of left chunks kept in cache, -1 for unlimited",
    )
    parser.add_argument(
        "--reverse_weight",
        default=0.5,
        type=float,
        help="reverse_weight in attention_rescoring",
    )
    args = parser.parse_args()
    return args
|
|
def to_numpy(tensor):
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    else:
        return tensor.cpu().numpy()
|
|
def print_input_output_info(onnx_model, name, prefix="\t\t"):
    input_names = [node.name for node in onnx_model.graph.input]
    input_shapes = [
        [d.dim_value for d in node.type.tensor_type.shape.dim]
        for node in onnx_model.graph.input
    ]
    output_names = [node.name for node in onnx_model.graph.output]
    output_shapes = [
        [d.dim_value for d in node.type.tensor_type.shape.dim]
        for node in onnx_model.graph.output
    ]
    print("{}{} inputs : {}".format(prefix, name, input_names))
    print("{}{} input shapes : {}".format(prefix, name, input_shapes))
    print("{}{} outputs: {}".format(prefix, name, output_names))
    print("{}{} output shapes : {}".format(prefix, name, output_shapes))
|
|
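# The streaming encoder graph exported below declares (shapes are also
# printed at export time; B = batch, T = chunk frames, T_CACHE = cached
# frames):
#   inputs : chunk, offset, required_cache_size, att_cache, cnn_cache, att_mask
#   outputs: output, r_att_cache, r_cnn_cache
# Inputs that get constant-folded during tracing (this depends on the
# chunk_size/left_chunks combination) disappear from the final graph.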
def export_encoder(asr_model, args):
    print("Stage-1: export encoder")
    encoder = asr_model.encoder
    # Export the streaming chunk-by-chunk entry point instead of the
    # full-utterance forward.
    encoder.forward = encoder.forward_chunk
    encoder_outpath = os.path.join(args["output_dir"], "encoder.onnx")

    print("\tStage-1.1: prepare inputs for encoder")
    chunk = torch.randn((args["batch"], args["decoding_window"], args["feature_size"]))
    offset = 0
    if args["left_chunks"] > 0:
        # Fixed left context: feed a real-shaped (zero-filled) cache plus a
        # mask that hides it, so the cache keeps a constant shape across
        # chunks and the export does not need a dynamic cache start.
        required_cache_size = args["chunk_size"] * args["left_chunks"]
        offset = required_cache_size
        att_cache = torch.zeros(
            (
                args["num_blocks"],
                args["head"],
                required_cache_size,
                # d_k * 2: key and value are packed along the last dim
                args["output_size"] // args["head"] * 2,
            )
        )
        att_mask = torch.ones(
            (args["batch"], 1, required_cache_size + args["chunk_size"]),
            dtype=torch.bool,
        )
        att_mask[:, :, :required_cache_size] = 0
    else:
        # Unlimited (-1) or zero left context: start from an empty cache
        # and a dummy mask.
        required_cache_size = -1 if args["left_chunks"] < 0 else 0
        att_cache = torch.zeros(
            (
                args["num_blocks"],
                args["head"],
                0,
                args["output_size"] // args["head"] * 2,
            )
        )
        att_mask = torch.ones((0, 0, 0), dtype=torch.bool)

    cnn_cache = torch.zeros(
        (
            args["num_blocks"],
            args["batch"],
            args["output_size"],
            args["cnn_module_kernel"] - 1,
        )
    )
    inputs = (chunk, offset, required_cache_size, att_cache, cnn_cache, att_mask)
    print(
        "\t\tchunk.size(): {}\n".format(chunk.size()),
        "\t\toffset: {}\n".format(offset),
        "\t\trequired_cache: {}\n".format(required_cache_size),
        "\t\tatt_cache.size(): {}\n".format(att_cache.size()),
        "\t\tcnn_cache.size(): {}\n".format(cnn_cache.size()),
        "\t\tatt_mask.size(): {}\n".format(att_mask.size()),
    )
|
print("\tStage-1.2: torch.onnx.export") |
|
dynamic_axes = { |
|
"chunk": {1: "T"}, |
|
"att_cache": {2: "T_CACHE"}, |
|
"att_mask": {2: "T_ADD_T_CACHE"}, |
|
"output": {1: "T"}, |
|
"r_att_cache": {2: "T_CACHE"}, |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
torch.onnx.export( |
|
encoder, |
|
inputs, |
|
encoder_outpath, |
|
opset_version=13, |
|
export_params=True, |
|
do_constant_folding=True, |
|
input_names=[ |
|
"chunk", |
|
"offset", |
|
"required_cache_size", |
|
"att_cache", |
|
"cnn_cache", |
|
"att_mask", |
|
], |
|
output_names=["output", "r_att_cache", "r_cnn_cache"], |
|
dynamic_axes=dynamic_axes, |
|
verbose=False, |
|
) |
|
onnx_encoder = onnx.load(encoder_outpath) |
|
for k, v in args.items(): |
|
meta = onnx_encoder.metadata_props.add() |
|
meta.key, meta.value = str(k), str(v) |
|
onnx.checker.check_model(onnx_encoder) |
|
onnx.helper.printable_graph(onnx_encoder.graph) |
|
|
|
|
|
onnx.save(onnx_encoder, encoder_outpath) |
|
print_input_output_info(onnx_encoder, "onnx_encoder") |
|
|
|
model_fp32 = encoder_outpath |
|
model_quant = os.path.join(args["output_dir"], "encoder.quant.onnx") |
|
quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8) |
|
print("\t\tExport onnx_encoder, done! see {}".format(encoder_outpath)) |
|
print("\tStage-1.3: check onnx_encoder and torch_encoder") |
|
torch_output = [] |
|
torch_chunk = copy.deepcopy(chunk) |
|
torch_offset = copy.deepcopy(offset) |
|
torch_required_cache_size = copy.deepcopy(required_cache_size) |
|
torch_att_cache = copy.deepcopy(att_cache) |
|
torch_cnn_cache = copy.deepcopy(cnn_cache) |
|
torch_att_mask = copy.deepcopy(att_mask) |
|
for i in range(10): |
|
print( |
|
"\t\ttorch chunk-{}: {}, offset: {}, att_cache: {}," |
|
" cnn_cache: {}, att_mask: {}".format( |
|
i, |
|
list(torch_chunk.size()), |
|
torch_offset, |
|
list(torch_att_cache.size()), |
|
list(torch_cnn_cache.size()), |
|
list(torch_att_mask.size()), |
|
) |
|
) |
|
|
|
|
|
if args["left_chunks"] > 0: |
|
torch_att_mask[:, :, -(args["chunk_size"] * (i + 1)) :] = 1 |
|
out, torch_att_cache, torch_cnn_cache = encoder( |
|
torch_chunk, |
|
torch_offset, |
|
torch_required_cache_size, |
|
torch_att_cache, |
|
torch_cnn_cache, |
|
torch_att_mask, |
|
) |
|
torch_output.append(out) |
|
torch_offset += out.size(1) |
|
torch_output = torch.cat(torch_output, dim=1) |
|
|
|
onnx_output = [] |
|
onnx_chunk = to_numpy(chunk) |
|
onnx_offset = np.array((offset)).astype(np.int64) |
|
onnx_required_cache_size = np.array((required_cache_size)).astype(np.int64) |
|
onnx_att_cache = to_numpy(att_cache) |
|
onnx_cnn_cache = to_numpy(cnn_cache) |
|
onnx_att_mask = to_numpy(att_mask) |
|
ort_session = onnxruntime.InferenceSession(encoder_outpath) |
|
input_names = [node.name for node in onnx_encoder.graph.input] |
|
for i in range(10): |
|
print( |
|
"\t\tonnx chunk-{}: {}, offset: {}, att_cache: {}," |
|
" cnn_cache: {}, att_mask: {}".format( |
|
i, |
|
onnx_chunk.shape, |
|
onnx_offset, |
|
onnx_att_cache.shape, |
|
onnx_cnn_cache.shape, |
|
onnx_att_mask.shape, |
|
) |
|
) |
|
|
|
|
|
if args["left_chunks"] > 0: |
|
onnx_att_mask[:, :, -(args["chunk_size"] * (i + 1)) :] = 1 |
|
ort_inputs = { |
|
"chunk": onnx_chunk, |
|
"offset": onnx_offset, |
|
"required_cache_size": onnx_required_cache_size, |
|
"att_cache": onnx_att_cache, |
|
"cnn_cache": onnx_cnn_cache, |
|
"att_mask": onnx_att_mask, |
|
} |
|
|
|
|
|
|
|
|
|
for k in list(ort_inputs): |
|
if k not in input_names: |
|
ort_inputs.pop(k) |
|
ort_outs = ort_session.run(None, ort_inputs) |
|
onnx_att_cache, onnx_cnn_cache = ort_outs[1], ort_outs[2] |
|
onnx_output.append(ort_outs[0]) |
|
onnx_offset += ort_outs[0].shape[1] |
|
onnx_output = np.concatenate(onnx_output, axis=1) |
|
|
|
np.testing.assert_allclose( |
|
to_numpy(torch_output), onnx_output, rtol=1e-03, atol=1e-05 |
|
) |
|
meta = ort_session.get_modelmeta() |
|
print("\t\tcustom_metadata_map={}".format(meta.custom_metadata_map)) |
|
print("\t\tCheck onnx_encoder, pass!") |
|
|
def export_ctc(asr_model, args):
    print("Stage-2: export ctc")
    ctc = asr_model.ctc
    # Export the log-softmax head so the runtime gets log posteriors.
    ctc.forward = ctc.log_softmax
    ctc_outpath = os.path.join(args["output_dir"], "ctc.onnx")

    print("\tStage-2.1: prepare inputs for ctc")
    hidden = torch.randn(
        (
            args["batch"],
            # use an arbitrary 16-frame length for full-context models
            args["chunk_size"] if args["chunk_size"] > 0 else 16,
            args["output_size"],
        )
    )

    print("\tStage-2.2: torch.onnx.export")
    dynamic_axes = {"hidden": {1: "T"}, "probs": {1: "T"}}
    torch.onnx.export(
        ctc,
        hidden,
        ctc_outpath,
        opset_version=13,
        export_params=True,
        do_constant_folding=True,
        input_names=["hidden"],
        output_names=["probs"],
        dynamic_axes=dynamic_axes,
        verbose=False,
    )
    onnx_ctc = onnx.load(ctc_outpath)
    for k, v in args.items():
        meta = onnx_ctc.metadata_props.add()
        meta.key, meta.value = str(k), str(v)
    onnx.checker.check_model(onnx_ctc)
    onnx.helper.printable_graph(onnx_ctc.graph)
    onnx.save(onnx_ctc, ctc_outpath)
    print_input_output_info(onnx_ctc, "onnx_ctc")

    model_fp32 = ctc_outpath
    model_quant = os.path.join(args["output_dir"], "ctc.quant.onnx")
    quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
    print("\t\tExport onnx_ctc, done! see {}".format(ctc_outpath))

    print("\tStage-2.3: check onnx_ctc and torch_ctc")
    torch_output = ctc(hidden)
    ort_session = onnxruntime.InferenceSession(ctc_outpath)
    onnx_output = ort_session.run(None, {"hidden": to_numpy(hidden)})

    np.testing.assert_allclose(
        to_numpy(torch_output), onnx_output[0], rtol=1e-03, atol=1e-05
    )
    print("\t\tCheck onnx_ctc, pass!")
|
|
def export_decoder(asr_model, args):
    print("Stage-3: export decoder")
    decoder = asr_model
    # Bind forward to the attention-rescoring entry point; encoder weights
    # that this path never touches are dropped from the traced graph.
    decoder.forward = decoder.forward_attention_decoder
    decoder_outpath = os.path.join(args["output_dir"], "decoder.onnx")

    print("\tStage-3.1: prepare inputs for decoder")
    # Dummy inputs: 200 encoder frames, 10 hypotheses of length 20. All of
    # these sizes are dynamic axes, so the concrete values are arbitrary.
    encoder_out = torch.randn((1, 200, args["output_size"]))
    hyps = torch.randint(low=0, high=args["vocab_size"], size=[10, 20])
    hyps[:, 0] = args["vocab_size"] - 1  # <sos>
    hyps_lens = torch.randint(low=15, high=21, size=[10])

    print("\tStage-3.2: torch.onnx.export")
    dynamic_axes = {
        "hyps": {0: "NBEST", 1: "L"},
        "hyps_lens": {0: "NBEST"},
        "encoder_out": {1: "T"},
        "score": {0: "NBEST", 1: "L"},
        "r_score": {0: "NBEST", 1: "L"},
    }
    inputs = (hyps, hyps_lens, encoder_out, args["reverse_weight"])
    torch.onnx.export(
        decoder,
        inputs,
        decoder_outpath,
        opset_version=13,
        export_params=True,
        do_constant_folding=True,
        input_names=["hyps", "hyps_lens", "encoder_out", "reverse_weight"],
        output_names=["score", "r_score"],
        dynamic_axes=dynamic_axes,
        verbose=False,
    )
    onnx_decoder = onnx.load(decoder_outpath)
    for k, v in args.items():
        meta = onnx_decoder.metadata_props.add()
        meta.key, meta.value = str(k), str(v)
    onnx.checker.check_model(onnx_decoder)
    onnx.helper.printable_graph(onnx_decoder.graph)
    onnx.save(onnx_decoder, decoder_outpath)
    print_input_output_info(onnx_decoder, "onnx_decoder")

    model_fp32 = decoder_outpath
    model_quant = os.path.join(args["output_dir"], "decoder.quant.onnx")
    quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8)
    print("\t\tExport onnx_decoder, done! see {}".format(decoder_outpath))

    print("\tStage-3.3: check onnx_decoder and torch_decoder")
    torch_score, torch_r_score = decoder(
        hyps, hyps_lens, encoder_out, args["reverse_weight"]
    )
    ort_session = onnxruntime.InferenceSession(decoder_outpath)
    input_names = [node.name for node in onnx_decoder.graph.input]
    ort_inputs = {
        "hyps": to_numpy(hyps),
        "hyps_lens": to_numpy(hyps_lens),
        "encoder_out": to_numpy(encoder_out),
        "reverse_weight": np.array(args["reverse_weight"]),
    }
    # reverse_weight is usually folded into the graph as a constant, in
    # which case it disappears from the declared inputs; drop anything
    # the exported graph does not expect.
    for k in list(ort_inputs):
        if k not in input_names:
            ort_inputs.pop(k)
    onnx_output = ort_session.run(None, ort_inputs)

    np.testing.assert_allclose(
        to_numpy(torch_score), onnx_output[0], rtol=1e-03, atol=1e-05
    )
    # The reverse score is only meaningful for a bidirectional decoder
    # with a non-zero reverse weight.
    if args["is_bidirectional_decoder"] and args["reverse_weight"] > 0.0:
        np.testing.assert_allclose(
            to_numpy(torch_r_score), onnx_output[1], rtol=1e-03, atol=1e-05
        )
    print("\t\tCheck onnx_decoder, pass!")
|
|
def main():
    torch.manual_seed(777)
    args = get_args()
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    # Force CPU: both the export and the onnxruntime checks run on CPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    with open(args.config, "r") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    model = init_model(configs)
    load_checkpoint(model, args.checkpoint)
    model.eval()
    print(model)

    arguments = {}
    arguments["output_dir"] = output_dir
    arguments["batch"] = 1
    arguments["chunk_size"] = args.chunk_size
    arguments["left_chunks"] = args.num_decoding_left_chunks
    arguments["reverse_weight"] = args.reverse_weight
    arguments["output_size"] = configs["encoder_conf"]["output_size"]
    arguments["num_blocks"] = configs["encoder_conf"]["num_blocks"]
    arguments["cnn_module_kernel"] = configs["encoder_conf"].get("cnn_module_kernel", 1)
    arguments["head"] = configs["encoder_conf"]["attention_heads"]
    arguments["feature_size"] = configs["input_dim"]
    arguments["vocab_size"] = configs["output_dim"]
    # decoding_window is the number of input frames needed to produce one
    # chunk of encoder output; fall back to a fixed 67-frame window when
    # chunk_size <= 0 (full-context export).
    arguments["decoding_window"] = (
        (args.chunk_size - 1) * model.encoder.embed.subsampling_rate
        + model.encoder.embed.right_context
        + 1
        if args.chunk_size > 0
        else 67
    )
    arguments["encoder"] = configs["encoder"]
    arguments["decoder"] = configs["decoder"]
    arguments["subsampling_rate"] = model.subsampling_rate()
    arguments["right_context"] = model.right_context()
    arguments["sos_symbol"] = model.sos_symbol()
    arguments["eos_symbol"] = model.eos_symbol()
    arguments["is_bidirectional_decoder"] = 1 if model.is_bidirectional_decoder() else 0

    # A fixed left-context cache only makes sense for chunk-based decoding.
    if arguments["left_chunks"] > 0:
        assert arguments["chunk_size"] > 0

    export_encoder(model, arguments)
    export_ctc(model, arguments)
    export_decoder(model, arguments)


if __name__ == "__main__":
    main()