File size: 3,159 Bytes
f3a1217 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
#!/usr/bin/env python
# coding: utf-8
import datetime
import argparse
from rknn.api import RKNN
from sys import exit
# Sequence lengths used to build the input shapes below.
AUDIO_LENGTH = 645  # audio length; 645 corresponds to 10 seconds
TEXT_LENGTH = 64  # text length, in tokens

# Model configuration: model type -> ONNX file to convert.
MODELS = dict(
    transformer='transformer.onnx',
    vae_decoder='vae_decoder.onnx',
)

# Input shapes per model type, passed to rknn.config(dynamic_input=...).
# Each model type maps to a list of shape sets; a shape set holds one shape
# per model input.
SHAPES = dict(
    transformer=[
        [
            [1, AUDIO_LENGTH, 64],       # hidden_states
            [1],                         # timestep
            [2, 1024],                   # pooled_text
            [2, TEXT_LENGTH, 1024],      # encoder_hidden_states
            [1, TEXT_LENGTH, 3],         # txt_ids
            [1, AUDIO_LENGTH, 3],        # img_ids
        ],
    ],
    vae_decoder=[
        [
            [1, 64, AUDIO_LENGTH],
        ],
    ],
)

# Module-level switches; neither is read by the code visible in this file.
QUANTIZE = False
detailed_performance_log = True
def convert_model(model_type):
    """Convert one configured ONNX model to an RKNN model file.

    The ONNX file name is looked up in ``MODELS``; the output file uses the
    same name with ``.onnx`` replaced by ``.rknn``. The input shapes from
    ``SHAPES[model_type]`` are passed as ``dynamic_input`` and the model is
    built for ``rk3588`` without quantization.

    Args:
        model_type: Key into ``MODELS`` / ``SHAPES`` (e.g. ``'transformer'``).

    Returns:
        True when the model was loaded, built and exported successfully;
        False when the model type is unknown or any RKNN step returns a
        non-zero code. A message is printed for each outcome.
    """
    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    rknn_model = onnx_model.replace(".onnx", ".rknn")
    # Conversion timestamp, embedded into the model via custom_string.
    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=SHAPES[model_type],
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
            # disable_rules=['convert_gemm_by_exmatmul']
        )

        print(f"开始转换 {model_type} 模型...")
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        ret = rknn.build(do_quantization=False, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        # Always free the toolkit's resources, including on the early
        # failure returns above, so converting several models in one run
        # does not accumulate unreleased RKNN instances.
        rknn.release()
def main():
    """Parse the command line and convert the requested model(s).

    The optional positional argument ``model_type`` selects a single model
    from ``MODELS``, or ``all`` (the default) to convert every entry.
    Conversion continues past a failed model so the remaining ones are
    still attempted.

    Returns:
        0 when every requested conversion succeeded, 1 when at least one
        failed, so the result can be used as the process exit status.
    """
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    parser.add_argument('model_type', nargs='?', default='all',
                        # Derived from MODELS so the accepted names cannot
                        # drift from the configured models.
                        choices=['all', *MODELS],
                        help='要转换的模型类型 (默认: all)')
    args = parser.parse_args()

    if args.model_type == 'all':
        targets = list(MODELS)
    else:
        targets = [args.model_type]

    failed = False
    for model_type in targets:
        if not convert_model(model_type):
            print(f"转换 {model_type} 失败")
            failed = True

    return 1 if failed else 0


if __name__ == '__main__':
    # Propagate failures to the shell instead of always exiting with 0.
    exit(main())
|