|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function |
|
|
|
import argparse |
|
import os |
|
|
|
import torch |
|
import yaml |
|
|
|
from wenet.utils.checkpoint import load_checkpoint |
|
from wenet.utils.init_model import init_model |
|
|
|
|
|
def get_args():
    """Parse the command-line options for the TorchScript export.

    Returns:
        argparse.Namespace with the attributes ``config`` and ``checkpoint``
        (both mandatory), plus ``output_file`` and ``output_quant_file``
        (both optional, ``None`` when not given).
    """
    cli = argparse.ArgumentParser(description="export your script model")

    # Mandatory inputs, registered in the same order as before so the
    # generated --help text is unchanged.
    required_opts = (
        ("--config", "config file"),
        ("--checkpoint", "checkpoint model"),
    )
    for flag, text in required_opts:
        cli.add_argument(flag, required=True, help=text)

    # Optional output destinations; each defaults to None.
    optional_opts = (
        ("--output_file", "output file"),
        ("--output_quant_file", "output quantized model file"),
    )
    for flag, text in optional_opts:
        cli.add_argument(flag, default=None, help=text)

    return cli.parse_args()
|
|
|
|
|
def main():
    """Build a model from a config + checkpoint and export it as TorchScript.

    Depending on the command-line options, this writes:
      * ``--output_file``: the float model compiled with ``torch.jit.script``;
      * ``--output_quant_file``: a copy whose ``torch.nn.Linear`` layers are
        dynamically quantized to int8, then compiled with ``torch.jit.script``.

    If neither output option is given, the model is only built, loaded and
    printed; nothing is written.
    """
    args = get_args()

    # Export does not need a GPU; hide all CUDA devices from this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    # NOTE(review): yaml FullLoader should only be used on trusted config
    # files; yaml.safe_load is the safer choice for plain-data configs.
    with open(args.config, "r", encoding="utf-8") as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    model = init_model(configs)
    print(model)

    load_checkpoint(model, args.checkpoint)

    # Put the model in inference mode before exporting. Without this the
    # float TorchScript module is saved with training=True, so modules such
    # as Dropout would behave as in training when the exported module is run
    # without an explicit eval(). (quantize_dynamic already calls eval() on
    # its own copy, so this matters mainly for the float export.)
    model.eval()

    if args.output_file:
        script_model = torch.jit.script(model)
        script_model.save(args.output_file)
        print("Export model successfully, see {}".format(args.output_file))

    if args.output_quant_file:
        # Dynamic int8 quantization of nn.Linear layers only. With the
        # default inplace=False, quantize_dynamic works on a deep copy and
        # leaves `model` untouched.
        quantized_model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )
        print(quantized_model)
        script_quant_model = torch.jit.script(quantized_model)
        script_quant_model.save(args.output_quant_file)
        print(
            "Export quantized model successfully, "
            "see {}".format(args.output_quant_file)
        )
|
|
|
|
|
if __name__ == "__main__":
    # Run the export only when executed as a script, not when imported.
    main()
|
|