File size: 3,159 Bytes
f3a1217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#!/usr/bin/env python
# coding: utf-8

import datetime
import argparse
from rknn.api import RKNN
from sys import exit

AUDIO_LENGTH = 645  # audio length; 645 corresponds to 10 seconds
TEXT_LENGTH = 64  # text length (in tokens)

# Model configuration: model type -> ONNX file to convert.
MODELS = dict(
    transformer='transformer.onnx',
    vae_decoder='vae_decoder.onnx',
)

# Model type -> value passed to RKNN's ``dynamic_input`` config option.
# Each model has a single shape set; the per-shape comments name the
# corresponding ONNX input (order presumably matches the model's input
# order -- confirm against the exported graph).
SHAPES = dict(
    transformer=[
        [
            [1, AUDIO_LENGTH, 64],     # hidden_states
            [1],                       # timestep
            [2, 1024],                 # pooled_text
            [2, TEXT_LENGTH, 1024],    # encoder_hidden_states
            [1, TEXT_LENGTH, 3],       # txt_ids
            [1, AUDIO_LENGTH, 3],      # img_ids
        ],
    ],
    vae_decoder=[
        [
            [1, 64, AUDIO_LENGTH],
        ],
    ],
)

# NOTE(review): presumably toggles quantization during conversion -- confirm.
QUANTIZE = False
# NOTE(review): not referenced in the visible code; presumably a
# performance-logging toggle.
detailed_performance_log = True

def convert_model(model_type):
    """Convert the ONNX model of the given type to RKNN format.

    The RKNN file is written alongside the ONNX file, with the trailing
    ``.onnx`` extension replaced by ``.rknn``.

    Args:
        model_type: Key into ``MODELS`` and ``SHAPES``
            (e.g. ``'transformer'``).

    Returns:
        True if load, build and export all succeeded; False if
        ``model_type`` is unknown or any toolkit step returned non-zero.
    """
    import os

    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    # Swap only the trailing extension; str.replace would also rewrite any
    # ".onnx" that appears earlier in the path.
    rknn_model = os.path.splitext(onnx_model)[0] + ".rknn"

    # Embedded in the RKNN file via custom_string to record conversion time.
    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=SHAPES[model_type],
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
            # Optional, currently disabled:
            #  disable_rules=['convert_gemm_by_exmatmul']
        )

        print(f"开始转换 {model_type} 模型...")
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        # Honour the module-level QUANTIZE switch instead of hard-coding False.
        # NOTE: per the RKNN-Toolkit2 docs, do_quantization=True also needs a
        # calibration `dataset` argument, which this script does not supply.
        ret = rknn.build(do_quantization=QUANTIZE, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        # Release toolkit resources on every path, including the early
        # failure returns above, so converting several models in one process
        # does not rely on garbage collection to free them.
        rknn.release()

def main():
    """Parse the command line and convert the requested model(s).

    With no argument (or ``all``) every model in ``MODELS`` is converted;
    otherwise only the named one. Each failure is reported, and the process
    exits with status 1 if any conversion failed so callers (shell scripts,
    CI) can detect it.
    """
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    parser.add_argument('model_type', nargs='?', default='all',
                        # Derived from MODELS so the CLI cannot drift from it.
                        choices=['all', *MODELS],
                        help='要转换的模型类型 (默认: all)')

    args = parser.parse_args()

    if args.model_type == 'all':
        # Convert every configured model.
        model_types = list(MODELS)
    else:
        # Convert only the requested model.
        model_types = [args.model_type]

    failed = []
    for model_type in model_types:
        if not convert_model(model_type):
            print(f"转换 {model_type} 失败")
            failed.append(model_type)

    # Previously failures still produced exit status 0.
    if failed:
        exit(1)


if __name__ == '__main__':
    main()