|
|
|
|
|
|
|
import datetime |
|
import argparse |
|
from rknn.api import RKNN |
|
from sys import exit |
|
|
|
# Fixed sequence lengths used in the input shapes below.
AUDIO_LENGTH = 645
TEXT_LENGTH = 64

# Supported model types, each mapped to its ONNX source file.
MODELS = dict(
    transformer='transformer.onnx',
    vae_decoder='vae_decoder.onnx',
)

# Per-model value for RKNN's ``dynamic_input``: a list of shape sets, where each
# set holds one shape per model input. Every model here has exactly one set.
# NOTE(review): the shape order presumably matches the ONNX graph's input
# order -- confirm against the exported models.
SHAPES = dict(
    transformer=[
        [
            [1, AUDIO_LENGTH, 64],
            [1],
            [2, 1024],
            [2, TEXT_LENGTH, 1024],
            [1, TEXT_LENGTH, 3],
            [1, AUDIO_LENGTH, 3],
        ],
    ],
    vae_decoder=[
        [
            [1, 64, AUDIO_LENGTH],
        ],
    ],
)

# Quantization switch for the RKNN build step.
# NOTE(review): confirm it is actually passed to rknn.build().
QUANTIZE = False

# NOTE(review): not referenced by the code in this file; possibly a leftover.
detailed_performance_log = True
|
|
|
def convert_model(model_type):
    """Convert one model listed in ``MODELS`` from ONNX to RKNN format.

    The output path is the ONNX path with ``.onnx`` replaced by ``.rknn``.
    Progress and error messages are printed to stdout.

    Args:
        model_type: A key of ``MODELS`` (e.g. ``"transformer"``). It also
            selects the ``dynamic_input`` shape sets from ``SHAPES``.

    Returns:
        True if loading, building and exporting all succeeded; False
        otherwise, including when ``model_type`` is not supported.
    """
    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    rknn_model = onnx_model.replace(".onnx", ".rknn")

    # Stored in the model's custom_string to record when it was converted.
    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    # Release the toolkit instance on every exit path (success, failure, or
    # exception); converting several models in one process would otherwise
    # leave one un-released RKNN object per model.
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            # Fixed input-shape sets for this model type.
            dynamic_input=SHAPES[model_type],
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
        )

        print(f"开始转换 {model_type} 模型...")
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        # Honour the module-level QUANTIZE switch instead of hardcoding False.
        ret = rknn.build(do_quantization=QUANTIZE, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        rknn.release()
|
|
|
def main():
    """Command-line entry point: convert one model type, or all of them.

    The optional positional argument is ``all`` (the default) or a key of
    ``MODELS``. In ``all`` mode every model is attempted, even if an earlier
    one fails. The process exits with status 1 if any conversion failed, so
    shell scripts and CI jobs can detect the failure.
    """
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    # Derive the choices from MODELS so the CLI cannot drift from the table.
    parser.add_argument('model_type', nargs='?', default='all',
                        choices=['all', *MODELS],
                        help='要转换的模型类型 (默认: all)')
    args = parser.parse_args()

    if args.model_type == 'all':
        model_types = list(MODELS)
    else:
        model_types = [args.model_type]

    failed = []
    for model_type in model_types:
        if not convert_model(model_type):
            print(f"转换 {model_type} 失败")
            failed.append(model_type)

    # Report failure through the exit status, not just on stdout.
    if failed:
        exit(1)
|
|
|
# Run the converter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|