#!/usr/bin/env python
# coding: utf-8
"""Convert the ONNX models listed in MODELS to RKNN format (target platform rk3588)."""
import argparse
import datetime
from sys import exit

from rknn.api import RKNN

AUDIO_LENGTH = 645  # audio length; 645 corresponds to 10 seconds
TEXT_LENGTH = 64  # text length (tokens)

# Model configuration: model type -> ONNX file to convert.
MODELS = {
    'transformer': 'transformer.onnx',
    'vae_decoder': 'vae_decoder.onnx',
}

# Input shapes passed to RKNN as `dynamic_input`, per model type: a list of
# shape sets, each set listing one shape per model input.
SHAPES = {
    'transformer': [
        [
            [1, AUDIO_LENGTH, 64],        # hidden_states
            [1, ],                        # timestep
            [2, 1024],                    # pooled_text
            [2, TEXT_LENGTH, 1024],       # encoder_hidden_states
            [1, TEXT_LENGTH, 3],          # txt_ids
            [1, AUDIO_LENGTH, 3],         # img_ids
        ],
    ],
    'vae_decoder': [
        [
            [1, 64, AUDIO_LENGTH],
        ],
    ],
}

# Passed to rknn.build() as do_quantization.
# NOTE(review): enabling this presumably also requires a calibration
# `dataset=` argument to build(), which this script does not supply.
QUANTIZE = False
# NOTE(review): not referenced anywhere in this script.
detailed_performance_log = True


def convert_model(model_type):
    """Convert one model from MODELS to an RKNN file.

    The output filename is the ONNX filename with ".onnx" replaced by ".rknn".

    Args:
        model_type: key into MODELS / SHAPES (e.g. 'transformer').

    Returns:
        True on success; False if the model type is unknown or if any RKNN
        step (load_onnx, build, export_rknn) returns a non-zero status.
    """
    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    rknn_model = onnx_model.replace(".onnx", ".rknn")
    # Timestamp embedded in the exported model via custom_string.
    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=SHAPES[model_type],
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
            # disable_rules=['convert_gemm_by_exmatmul']
        )

        print(f"开始转换 {model_type} 模型...")
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        ret = rknn.build(do_quantization=QUANTIZE, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        # Release toolkit resources on every exit path; in 'all' mode one
        # RKNN instance is created per model.
        rknn.release()


def main():
    """Parse the CLI and convert the requested model(s).

    Converts every model in MODELS when model_type is 'all' (the default),
    otherwise only the named one. Every requested model is attempted; the
    process exits with status 1 if any conversion failed.
    """
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    parser.add_argument('model_type', nargs='?', default='all',
                        choices=['all', *MODELS],
                        help='要转换的模型类型 (默认: all)')
    args = parser.parse_args()

    if args.model_type == 'all':
        targets = list(MODELS)  # convert all models
    else:
        targets = [args.model_type]  # convert only the requested model

    failed = []
    for model_type in targets:
        if not convert_model(model_type):
            print(f"转换 {model_type} 失败")
            failed.append(model_type)

    # Non-zero exit status so shell/CI callers can detect failed conversions.
    if failed:
        exit(1)


if __name__ == '__main__':
    main()