happyme531 committed
Upload 10 files

- .gitattributes +3 -0
- RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.emb +3 -0
- RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx +3 -0
- RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn +3 -0
- convert_rknn.py +54 -0
- ea50ffd6-c6fe-11ef-8ff3-1c860b30973e +3 -0
- export_onnx.py +120 -0
- inference.py +210 -0
- rwkv_tokenizer.py +89 -0
- rwkv_vocab_v20230424.txt +0 -0
- ztu_somemodelruntime_rknnlite2.py +509 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ea50ffd6-c6fe-11ef-8ff3-1c860b30973e filter=lfs diff=lfs merge=lfs -text
+RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.emb filter=lfs diff=lfs merge=lfs -text
+RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn filter=lfs diff=lfs merge=lfs -text
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.emb
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7bb71268884738ee0bbc62796b838afd9b460da931589151d949e538cbe58255
+size 201326592
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:080c9153102fe9c2c54e8245411a9ab70360132a13321c3396dd7cca17eca1c4
+size 305312
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17c375a232e19992bba49459fa7a092ecdb6252841b80850095eb5c6fb4e2bf4
+size 289121271
convert_rknn.py
ADDED
@@ -0,0 +1,54 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+import datetime
+from rknn.api import RKNN
+from sys import exit
+
+
+ONNX_MODEL = "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx"
+RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+DATASET = ""
+QUANTIZE = False
+detailed_performance_log = True
+
+timedate_iso = datetime.datetime.now().isoformat()
+
+rknn = RKNN(verbose=True)
+rknn.config(
+    # mean_values=[x * 255 for x in [0.485, 0.456, 0.406]],
+    # std_values=[x * 255 for x in [0.229, 0.224, 0.225]],
+    quantized_dtype="w8a8",
+    quantized_algorithm="normal",
+    quantized_method="channel",
+    quantized_hybrid_level=0,
+    target_platform="rk3588",
+    quant_img_RGB2BGR=False,
+    float_dtype="float16",
+    optimization_level=3,
+    custom_string=f"converted at {timedate_iso}",
+    remove_weight=False,
+    compress_weight=False,
+    inputs_yuv_fmt=None,
+    single_core_mode=False,
+    dynamic_input=None,
+    model_pruning=False,
+    op_target=None,
+    quantize_weight=False,
+    remove_reshape=False,
+    sparse_infer=False,
+    enable_flash_attention=False,
+    # hidden (undocumented) parameters
+    # disable_rules=[],
+    # sram_prefer=False,
+    # nbuf_prefer=False,
+    # check_data=[],
+)
+
+ret = rknn.load_onnx(model=ONNX_MODEL)
+ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+ret = rknn.export_rknn(RKNN_MODEL)
+
+# ret = rknn.init_runtime(target='rk3588', device_id='cbb956772bf5dac9', core_mask=RKNN.NPU_CORE_0, perf_debug=detailed_performance_log)
+# rknn.eval_perf()
+# ret = rknn.accuracy_analysis(inputs=['../embeddings.npy','../state.npy','../scale_ratio.npy'], target='rk3588', device_id=device_id)
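
Note: the script above assigns but never checks the `ret` codes from load_onnx/build/export_rknn. A minimal sketch of the same three steps with error checking (the abbreviated config is illustrative, not the full option set used above):

# Sketch: conversion with return codes checked, so a failed
# load/build/export raises instead of silently writing nothing.
from rknn.api import RKNN

def convert(onnx_path: str, rknn_path: str) -> None:
    rknn = RKNN(verbose=True)
    rknn.config(target_platform="rk3588", float_dtype="float16", optimization_level=3)
    if rknn.load_onnx(model=onnx_path) != 0:
        raise RuntimeError("load_onnx failed")
    if rknn.build(do_quantization=False) != 0:
        raise RuntimeError("build failed")
    if rknn.export_rknn(rknn_path) != 0:
        raise RuntimeError("export_rknn failed")
    rknn.release()

convert("RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx",
        "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn")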
ea50ffd6-c6fe-11ef-8ff3-1c860b30973e
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:faa4dce148b8ed0172ef021b8f732c2eea5dd782caed801dd4727d909d2b9447
+size 562805760
export_onnx.py
ADDED
@@ -0,0 +1,120 @@
+from rwkv_src.rwkv_model import RWKV_RNN, make_chunks
+import types
+import os
+import torch
+import numpy as np
+import argparse
+import json
+import copy
+from pathlib import Path
+import onnx
+from onnx import shape_inference
+
+parser = argparse.ArgumentParser(description='Convert model')
+parser.add_argument('model', type=Path, help='Path to RWKV pth file')
+parser.add_argument('--chunks', type=int, default=1, help='Number of chunks')
+parser.add_argument('--ext_embedding', action='store_true', default=False, help='Use external embedding')
+parser.add_argument('--prefill_model', action='store_true', help='Convert model for sequential prefill')
+parser.add_argument('--wkv_customop', action='store_true', help='Use custom op for wkv')
+parser_args = parser.parse_args()
+
+seq_length = 32 if parser_args.prefill_model else 1
+
+model_args = types.SimpleNamespace()
+model_args.USE_CUDA = False
+model_args.fp16 = False
+model_args.wkv_customop = parser_args.wkv_customop
+model_args.USE_EMBEDDING = False if parser_args.ext_embedding else True
+
+model_args.MODEL_NAME = str(parser_args.model)
+
+if 'ABC' in model_args.MODEL_NAME or 'MIDI' in model_args.MODEL_NAME or 'x070' in model_args.MODEL_NAME:
+    model_args.RESCALE_LAYER = 0
+else:
+    model_args.RESCALE_LAYER = 6
+
+model = make_chunks(parser_args.chunks, model_args) if parser_args.chunks > 1 else RWKV_RNN(model_args)
+
+if parser_args.prefill_model:
+    model_args.MODEL_NAME = model_args.MODEL_NAME + "_prefill"
+
+os.path.exists("onnx") or os.mkdir("onnx")
+
+if type(model) == list:
+    args = model[0].args
+    if not args.USE_EMBEDDING:
+        model[0].emb_weight.cpu().numpy().astype(np.float32).tofile("onnx/" + args.MODEL_NAME.split("/")[-1] + f"_chunk1of{len(model)}.emb")
+    args = model[0].args
+    fp16 = args.fp16
+    states = []
+    for i in range(args.n_layer):
+        states.append(torch.zeros(1, args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
+        states.append(torch.zeros(args.n_head, args.head_size, args.head_size, dtype=torch.float16 if fp16 else torch.float32))
+        states.append(torch.zeros(1, args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
+    if model[0].device != torch.device('cpu'):  # compare devices by value; `is not` checks identity
+        states = [i.to(model[0].device) for i in states]
+
+    for i in range(len(model)):
+        dirname = "onnx/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}"
+        os.path.exists(dirname) or os.mkdir(dirname)
+        if i == 0 and args.USE_EMBEDDING:
+            in0 = torch.LongTensor([[1]*seq_length])
+        else:
+            in0 = torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
+
+        if model[0].device != torch.device('cpu'):
+            in0 = in0.to(model[0].device)
+        inputs = {'in0': in0, 'state': [states[j] for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]}
+        input_names = ['in'] + [f'state{j}_in' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]
+        output_names = ['out'] + [f'state{j}_out' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]
+
+        if args.wkv_customop:
+            from torch.onnx.symbolic_helper import _get_tensor_sizes
+            from torch.onnx import register_custom_op_symbolic
+            op_name = "rwkv::wkv_chunk" if parser_args.prefill_model else "rwkv::wkv"
+            def onnx_custom_wkv(g, k, v, r, state2, time_first, time_decay):
+                out1, out2 = g.op(op_name, k, v, r, state2, time_first, time_decay, outputs=2)
+                return out1.setType(k.type().with_dtype(torch.float32).with_sizes([seq_length, _get_tensor_sizes(k)[0], 1, args.head_size])),\
+                    out2.setType(k.type().with_dtype(torch.float32).with_sizes([1, _get_tensor_sizes(k)[0], args.head_size, args.head_size]))
+            register_custom_op_symbolic(op_name, onnx_custom_wkv, 9)
+
+        torch.onnx.export(model[i], inputs, dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx", input_names=input_names, output_names=output_names, opset_version=17)
+        shape_inference.infer_shapes_path(dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx")
+        onnx_model = onnx.load(dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx")
+
+        # To make the model compatible with other frameworks
+        for initializer in onnx_model.graph.initializer:
+            shape = list(initializer.dims)
+            value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
+            onnx_model.graph.value_info.append(value_info)
+        onnx.save_model(onnx_model, dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx", save_as_external_data=True, all_tensors_to_one_file=True)
+        print(f"onnx model chunk{i} saved to {dirname}" + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx")
+
+else:
+    args = model.args
+    if not args.USE_EMBEDDING:
+        model.emb_weight.cpu().numpy().astype(np.float32).tofile("onnx/" + args.MODEL_NAME.split("/")[-1] + ".emb")
+    args = model.args
+    fp16 = args.fp16
+    in0 = torch.LongTensor([[1]*seq_length]) if args.USE_EMBEDDING else torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
+    states = []
+    for i in range(model.args.n_layer):
+        states.append(torch.zeros(1, model.args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
+        states.append(torch.zeros(model.args.n_head, model.args.head_size, model.args.head_size, dtype=torch.float16 if fp16 else torch.float32))
+        states.append(torch.zeros(1, model.args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
+    if model.device != torch.device('cpu'):
+        states = [tensor.to(model.device) for tensor in states]
+    inputs = {'in0': in0, 'state': states}
+    input_names = ['in'] + [f'state{i}_in' for i in range(3*model.args.n_layer)]
+    output_names = ['logits'] + [f'state{i}_out' for i in range(3*model.args.n_layer)]
+    torch.onnx.export(model, inputs, "onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx", input_names=input_names, output_names=output_names, opset_version=17)
+    shape_inference.infer_shapes_path("onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx")
+    onnx_model = onnx.load("onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx")
+
+    # To make the model compatible with other frameworks
+    for initializer in onnx_model.graph.initializer:
+        shape = list(initializer.dims)
+        value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
+        onnx_model.graph.value_info.append(value_info)
+    onnx.save_model(onnx_model, "onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx", save_as_external_data=True, all_tensors_to_one_file=True)
+    print("onnx model saved to onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx")
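
A quick way to sanity-check an exported single-chunk model is to run one decode step with onnxruntime, deriving the state shapes from the graph itself. A minimal sketch (the path is a placeholder, and it assumes the embedding was kept inside the graph so 'in' takes int64 token ids):

# Sketch: one decode step over the exported graph with zeroed states.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("onnx/model.onnx", providers=["CPUExecutionProvider"])
feed = {}
for inp in sess.get_inputs():
    if inp.name == "in":
        feed[inp.name] = np.array([[1]], dtype=np.int64)         # a single token id
    else:
        feed[inp.name] = np.zeros(inp.shape, dtype=np.float32)   # zero state
outputs = sess.run(None, feed)
print("logits:", outputs[0].shape)  # first output; the rest are state{i}_out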
inference.py
ADDED
@@ -0,0 +1,210 @@
+# import onnxruntime as ort  # Uncomment this line to use onnxruntime
+import ztu_somemodelruntime_rknnlite2 as ort  # Comment this line out to use onnxruntime
+import numpy as np
+from pathlib import Path
+from rwkv_tokenizer import RWKV_TOKENIZER
+import time
+
+class RWKVModel:
+    def __init__(self, model_path: str, tokenizer_path: str = None, use_external_embedding: bool = False):
+        # Load the model
+        session_options = ort.SessionOptions()
+        # session_options.core_mask = 7  # 0b00000111: use NPU cores 0, 1 and 2
+        self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'], session_options=session_options)
+
+        # Print model input info
+        print("\nModel inputs:")
+        for inp in self.session.get_inputs():
+            print(f"{inp.name}: shape={inp.shape}, type={inp.type}")
+
+        # Gather model info
+        self.n_layer = len([x for x in self.session.get_inputs() if 'state' in x.name]) // 3
+        self.n_embd = self.session.get_inputs()[0].shape[-1] if not use_external_embedding else None
+
+        # Read the state tensor shapes from the model
+        self.state_shapes = {}
+        for inp in self.session.get_inputs():
+            if 'state' in inp.name:
+                self.state_shapes[inp.name] = inp.shape
+
+        print("\nNumber of layers:", self.n_layer)
+
+        # Load the tokenizer
+        if tokenizer_path:
+            self.tokenizer = RWKV_TOKENIZER(tokenizer_path)
+        else:
+            self.tokenizer = None
+
+        # Load the external embedding table (if used)
+        self.use_external_embedding = use_external_embedding
+        if use_external_embedding:
+            emb_path = Path(model_path).parent / (Path(model_path).stem + '.emb')
+            self.embedding = np.fromfile(emb_path, dtype=np.float32)
+            # Reshape the flat embedding array
+            vocab_size = len(self.embedding) // 768  # assumes an embedding dim of 768
+            self.embedding = self.embedding.reshape(vocab_size, 768)
+            self.n_embd = 768
+            print(f"\nEmbedding shape: {self.embedding.shape}")
+
+        # Initialize the states
+        self.reset_state()
+
+    def reset_state(self):
+        """Reset all states to zero"""
+        self.states = []
+        for i in range(self.n_layer * 3):
+            state_name = f'state{i}_in'
+            state_shape = self.state_shapes[state_name]
+            self.states.append(np.zeros(state_shape, dtype=np.float32))
+
+    def _prepare_inputs(self, token_id):
+        """Prepare model inputs"""
+        inputs = {}
+
+        # Main input
+        if self.use_external_embedding:
+            # Look up the external embedding
+            embedding = self.embedding[token_id].reshape(1, 1, self.n_embd)
+            inputs['in'] = embedding.astype(np.float32)
+        else:
+            # Feed the token id directly
+            inputs['in'] = np.array([[token_id]], dtype=np.int64)
+
+        # Attach the states
+        for i in range(len(self.states)):
+            inputs[f'state{i}_in'] = self.states[i]
+
+        # Print the input shapes (only when the token id is 0)
+        if token_id == 0:
+            print("\nPrepared input shapes:")
+            for k, v in inputs.items():
+                print(f"{k}: shape={v.shape}, type={v.dtype}")
+
+        return inputs
+
+    def forward(self, token_id):
+        """Single-step inference"""
+        # Prepare inputs
+        inputs = self._prepare_inputs(token_id)
+
+        # Run inference
+        outputs = self.session.run(None, inputs)
+
+        # Print output info (only when the token id is 0)
+        if token_id == 0:
+            print("\nModel outputs:")
+            for i, out in enumerate(outputs):
+                print(f"Output {i}: shape={out.shape}, type={out.dtype}")
+
+        # Update the states
+        for i in range(len(self.states)):
+            new_state = outputs[i + 1]  # the first output is the logits
+            # Make sure the shapes match
+            if new_state.shape != self.states[i].shape:
+                if token_id == 0:
+                    print(f"\nState shape mismatch for state{i}_in:")
+                    print(f"Expected: {self.states[i].shape}")
+                    print(f"Got: {new_state.shape}")
+                # Squeeze away the extra dimension
+                if len(self.states[i].shape) == 2:    # (1, 768)
+                    new_state = new_state.squeeze(1)  # (1, 1, 768) -> (1, 768)
+                elif len(self.states[i].shape) == 3:  # (12, 64, 64)
+                    new_state = new_state.squeeze(0)  # (1, 12, 64, 64) -> (12, 64, 64)
+            self.states[i] = new_state
+
+        return outputs[0]  # return the logits
+
+    def generate(self, prompt: str, max_length: int = 100, temperature: float = 1.0, stop_tokens: set = None):
+        """Generate text"""
+        if not self.tokenizer:
+            raise ValueError("A tokenizer is required for text generation")
+
+        # Encode the prompt
+        tokens = self.tokenizer.encode(prompt)
+        generated = list(tokens)
+
+        # Reset the states
+        self.reset_state()
+
+        # Process the prompt
+        print("\nProcessing prompt...", end='', flush=True)
+        t_start = time.time()
+        for token in tokens:
+            logits = self.forward(token)
+        t_prompt = time.time() - t_start
+        print(f" Done. ({len(tokens)} tokens, {t_prompt:.2f}s, {len(tokens)/t_prompt:.2f} tokens/s)")
+
+        # Generate new tokens
+        print("\nGenerating:", end='', flush=True)
+        t_start = time.time()
+        generated_tokens = 0
+
+        for i in range(max_length):
+            # Get the logits and apply temperature
+            t_token_start = time.time()
+            if i > 0:
+                # The last prompt token was already fed above, so on the first
+                # iteration we sample directly from its logits instead of
+                # feeding it a second time.
+                logits = self.forward(generated[-1])
+
+            # Print logits info on the first step
+            if i == 0:
+                print(f"\nLogits shape: {logits.shape}")
+
+            # Flatten the logits to 1-D
+            logits = logits.reshape(-1)
+
+            if temperature > 0:
+                # Apply temperature and compute probabilities
+                logits = logits / temperature
+                # Subtract the max to avoid overflow in exp
+                logits = logits - np.max(logits)
+                probs = np.exp(logits)
+                probs = probs / np.sum(probs)
+                next_token = np.random.choice(len(probs), p=probs)
+            else:
+                next_token = np.argmax(logits)
+
+            generated.append(next_token)
+            generated_tokens += 1
+
+            # Stop if a stop token was generated
+            if stop_tokens and next_token in stop_tokens:
+                break
+
+            # Stream the newly generated token
+            new_text = self.tokenizer.decode([next_token])
+            print(new_text, end='', flush=True)
+
+        t_generate = time.time() - t_start
+        print(f"\n\nGeneration finished: {generated_tokens} tokens generated in {t_generate:.2f}s ({generated_tokens/t_generate:.2f} tokens/s)")
+
+        return self.tokenizer.decode(generated)
+
+def main():
+    import time
+
+    # Usage example
+    print("Loading model...")
+    t_start = time.time()
+    model = RWKVModel(
+        model_path='RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx',
+        tokenizer_path='rwkv_vocab_v20230424.txt',
+        use_external_embedding=True
+    )
+    print(f"Model loaded in {time.time() - t_start:.2f}s")
+
+    prompt = "Here is an example of the Quick Sort algorithm implemented in C++:\n```cpp"
+    print(f"\nPrompt: {prompt}")
+
+    generated_text = model.generate(
+        prompt=prompt,
+        max_length=1024,
+        temperature=0.7,
+        stop_tokens={0, 1, 2, 3}  # special tokens used as stop markers
+    )
+
+    print("\nFull text:")
+    print(generated_text)
+
+if __name__ == '__main__':
+    main()
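
For finer control than generate(), forward() can be driven directly; a minimal sketch using the same model files as in main() above:

# Sketch: feed a prompt token by token, then greedily pick the next token.
import numpy as np
from inference import RWKVModel

model = RWKVModel('RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx',
                  tokenizer_path='rwkv_vocab_v20230424.txt',
                  use_external_embedding=True)
model.reset_state()
logits = None
for tok in model.tokenizer.encode("Hello"):
    logits = model.forward(tok)           # updates the recurrent state
next_id = int(np.argmax(logits.reshape(-1)))
print(model.tokenizer.decode([next_id]))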
rwkv_tokenizer.py
ADDED
@@ -0,0 +1,89 @@
+from typing import List, Set, Dict
+
+class ABCTokenizer():
+    def __init__(self):
+        self.pad_token_id = 0
+        self.bos_token_id = 2
+        self.eos_token_id = 3
+    def encode(self, text):
+        ids = [ord(c) for c in text]
+        return ids
+    def decode(self, ids):
+        txt = ''.join(chr(idx) if idx > self.eos_token_id else '' for idx in ids if idx != self.eos_token_id)
+        return txt
+
+class RWKV_TOKENIZER():
+    table: List[List[List[bytes]]]
+    good: List[Set[int]]
+    wlen: List[int]
+    def __init__(self, file_name):
+        self.idx2token = {}
+        sorted = []  # must be already sorted
+        lines = open(file_name, "r", encoding="utf-8").readlines()
+        for l in lines:
+            idx = int(l[:l.index(' ')])
+            x = eval(l[l.index(' '):l.rindex(' ')])
+            x = x.encode("utf-8") if isinstance(x, str) else x
+            assert isinstance(x, bytes)
+            assert len(x) == int(l[l.rindex(' '):])
+            sorted += [x]
+            self.idx2token[idx] = x
+
+        self.token2idx = {}
+        for k, v in self.idx2token.items():
+            self.token2idx[v] = int(k)
+
+        # precompute some tables for fast matching
+        self.table = [[[] for j in range(256)] for i in range(256)]
+        self.good = [set() for i in range(256)]
+        self.wlen = [0 for i in range(256)]
+
+        for i in reversed(range(len(sorted))):  # reverse order - match longer tokens first
+            s = sorted[i]
+            if len(s) >= 2:
+                s0 = int(s[0])
+                s1 = int(s[1])
+                self.table[s0][s1] += [s]
+                self.wlen[s0] = max(self.wlen[s0], len(s))
+                self.good[s0].add(s1)
+
+    def encodeBytes(self, src: bytes) -> List[int]:
+        src_len: int = len(src)
+        tokens: List[int] = []
+        i: int = 0
+        while i < src_len:
+            s: bytes = src[i : i + 1]
+
+            if i < src_len - 1:
+                s1: int = int(src[i + 1])
+                s0: int = int(src[i])
+                if s1 in self.good[s0]:
+                    sss: bytes = src[i : i + self.wlen[s0]]
+                    try:
+                        s = next(filter(sss.startswith, self.table[s0][s1]))
+                    except:
+                        pass
+            tokens.append(self.token2idx[s])
+            i += len(s)
+
+        return tokens
+
+    def decodeBytes(self, tokens):
+        return b''.join(map(lambda i: self.idx2token[i], tokens))
+
+    def encode(self, src: str):
+        return self.encodeBytes(src.encode("utf-8"))
+
+    def decode(self, tokens):
+        return self.decodeBytes(tokens).decode('utf-8')
+
+    def printTokens(self, tokens):
+        for i in tokens:
+            s = self.idx2token[i]
+            try:
+                s = s.decode('utf-8')
+            except:
+                pass
+            print(f'{repr(s)}{i}', end=' ')
+            # print(repr(s), i)
+        print()
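
A quick round-trip check for the tokenizer (assumes rwkv_vocab_v20230424.txt is in the working directory):

from rwkv_tokenizer import RWKV_TOKENIZER

tok = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
ids = tok.encode("Hello world")
assert tok.decode(ids) == "Hello world"  # byte-level round trip
tok.printTokens(ids)                     # prints each token with its id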
rwkv_vocab_v20230424.txt
ADDED
The diff for this file is too large to render. See raw diff.
ztu_somemodelruntime_rknnlite2.py
ADDED
@@ -0,0 +1,509 @@
+# Module-level constants and functions
+from rknnlite.api import RKNNLite
+import numpy as np
+import os
+import warnings
+import logging
+from typing import List, Dict, Union, Optional
+
+# Logging setup
+logger = logging.getLogger("somemodelruntime_rknnlite2")
+logger.setLevel(logging.ERROR)  # only log errors by default
+if not logger.handlers:
+    handler = logging.StreamHandler()
+    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+    logger.addHandler(handler)
+
+# Map ONNX Runtime log levels to Python logging levels
+_LOGGING_LEVEL_MAP = {
+    0: logging.DEBUG,     # Verbose
+    1: logging.INFO,      # Info
+    2: logging.WARNING,   # Warning
+    3: logging.ERROR,     # Error
+    4: logging.CRITICAL   # Fatal
+}
+
+def set_default_logger_severity(level: int) -> None:
+    """
+    Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
+
+    Args:
+        level: log level (0-4)
+    """
+    if level not in _LOGGING_LEVEL_MAP:
+        raise ValueError(f"Invalid log level: {level}, expected an integer in 0-4")
+    logger.setLevel(_LOGGING_LEVEL_MAP[level])
+
+def set_default_logger_verbosity(level: int) -> None:
+    """
+    Sets the default logging verbosity level. To activate the verbose log,
+    you need to set the default logging severity to 0:Verbose level.
+
+    Args:
+        level: log level (0-4)
+    """
+    set_default_logger_severity(level)
+
+# NPU core mode constants
+NPU_CORE_AUTO = 0      # auto-select
+NPU_CORE_0 = 1         # use core 0
+NPU_CORE_1 = 2         # use core 1
+NPU_CORE_2 = 4         # use core 2
+NPU_CORE_0_1 = 3       # use cores 0 and 1
+NPU_CORE_0_1_2 = 7     # use cores 0, 1 and 2
+NPU_CORE_ALL = 0xffff  # use all cores
+
+# Map RKNN tensor types to numpy dtypes
+RKNN_DTYPE_MAP = {
+    0: np.float32,   # RKNN_TENSOR_FLOAT32
+    1: np.float16,   # RKNN_TENSOR_FLOAT16
+    2: np.int8,      # RKNN_TENSOR_INT8
+    3: np.uint8,     # RKNN_TENSOR_UINT8
+    4: np.int16,     # RKNN_TENSOR_INT16
+    5: np.uint16,    # RKNN_TENSOR_UINT16
+    6: np.int32,     # RKNN_TENSOR_INT32
+    7: np.uint32,    # RKNN_TENSOR_UINT32
+    8: np.int64,     # RKNN_TENSOR_INT64
+    9: bool,         # RKNN_TENSOR_BOOL
+    10: np.int8,     # RKNN_TENSOR_INT4 (represented as int8)
+}
+
+def get_available_providers() -> List[str]:
+    """
+    Get the list of available providers (placeholder for interface compatibility)
+
+    Returns:
+        list: available providers, always ["CPUExecutionProvider"]
+    """
+    return ["CPUExecutionProvider"]
+
+def get_version_info() -> Dict[str, str]:
+    """
+    Get version information
+
+    Returns:
+        dict: API and driver version info
+    """
+    runtime = RKNNLite()
+    version = runtime.get_sdk_version()
+    return {
+        "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
+        "driver_version": version.split('\n')[3].split(': ')[1]
+    }
+
+class IOTensor:
+    """Wrapper for input/output tensor info"""
+    def __init__(self, name, shape, type=None):
+        self.name = name.decode() if isinstance(name, bytes) else name
+        self.shape = shape
+        self.type = type
+
+    def __str__(self):
+        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
+
+class SessionOptions:
+    """Session options"""
+    def __init__(self):
+        self.async_mode = False  # whether to run in async mode
+        self.core_mask = 0       # NPU core selection
+        self.perf_debug = False  # whether to enable performance profiling
+
+class InferenceSession:
+    """
+    RKNNLite runtime wrapper with an ONNX Runtime-style API
+    """
+
+    def __init__(self, model_path: str, verbose: bool = False, session_options: Optional[SessionOptions] = None, **kwargs):
+        """
+        Initialize the runtime and load the model
+
+        Args:
+            model_path: model file path (.rknn or .onnx)
+            verbose: whether to print verbose logs
+            session_options: session options
+            **kwargs: other init parameters
+        """
+        # Enable verbose logging only when verbose=True
+        if verbose:
+            set_default_logger_severity(0)  # Verbose
+
+        self.model_path = self._process_model_path(model_path)
+        self.runtime = RKNNLite(verbose=verbose)
+
+        # Load the model
+        logger.debug(f"Loading model: {self.model_path}")
+        ret = self.runtime.load_rknn(self.model_path)
+        if ret != 0:
+            logger.error(f"Failed to load RKNN model: {self.model_path}")
+            raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
+        logger.debug("Model loaded")
+
+        # Apply session options
+        options = session_options or SessionOptions()
+
+        # Initialize the runtime
+        logger.debug("Initializing runtime environment")
+        ret = self.runtime.init_runtime(
+            async_mode=options.async_mode,
+            core_mask=options.core_mask
+        )
+        if ret != 0:
+            logger.error("Failed to initialize runtime environment")
+            raise RuntimeError('Failed to initialize runtime environment')
+        logger.debug("Runtime environment initialized")
+
+        # Collect input/output info
+        self._init_io_info()
+
+        # Keep the options
+        self.options = options
+
+    def get_performance_info(self) -> Dict[str, float]:
+        """
+        Get performance info
+
+        Returns:
+            dict: performance info
+        """
+        if not self.options.perf_debug:
+            raise RuntimeError("Profiling is not enabled; set perf_debug=True in SessionOptions")
+
+        perf = self.runtime.rknn_runtime.get_run_perf()
+        return {
+            "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
+        }
+
+    def set_core_mask(self, core_mask: int) -> None:
+        """
+        Set the NPU core usage mode
+
+        Args:
+            core_mask: NPU core mask; use the NPU_CORE_* constants
+        """
+        ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
+        if ret != 0:
+            raise RuntimeError("Failed to set NPU core mode")
+
+    def _process_model_path(self, model_path):
+        """Handle the model path; supports .onnx and .rknn files"""
+        if not os.path.exists(model_path):
+            logger.error(f"Model file not found: {model_path}")
+            raise FileNotFoundError(f"Model file not found: {model_path}")
+
+        # If an ONNX file is given
+        if model_path.lower().endswith('.onnx'):
+            logger.warning(
+                "ONNX model file detected. Note: SomeModelRuntime does not convert ONNX to RKNN automatically. "
+                "Please convert the model with RKNN Toolkit first. "
+                "Trying to load the .rknn file with the same name."
+            )
+            # Build the RKNN file path
+            rknn_path = os.path.splitext(model_path)[0] + '.rknn'
+            if not os.path.exists(rknn_path):
+                logger.error(f"RKNN model file not found: {rknn_path}")
+                raise FileNotFoundError(
+                    f"RKNN model file not found: {rknn_path}\n"
+                    "Please convert the ONNX model to RKNN format with RKNN Toolkit first."
+                )
+            return rknn_path
+
+        return model_path
+
+    def _convert_nhwc_to_nchw(self, shape):
+        """Convert an NHWC shape to NCHW"""
+        if len(shape) == 4:
+            # NHWC -> NCHW
+            n, h, w, c = shape
+            return [n, c, h, w]
+        return shape
+
+    def _init_io_info(self):
+        """Initialize the model's input/output info"""
+        runtime = self.runtime.rknn_runtime
+
+        # Get the number of inputs and outputs
+        n_input, n_output = runtime.get_in_out_num()
+
+        # Input info
+        self.input_tensors = []
+        for i in range(n_input):
+            attr = runtime.get_tensor_attr(i)
+            shape = [attr.dims[j] for j in range(attr.n_dims)]
+            # Convert 4-D inputs from NHWC to NCHW
+            shape = self._convert_nhwc_to_nchw(shape)
+            # Get dtype
+            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+            tensor = IOTensor(attr.name, shape, dtype)
+            self.input_tensors.append(tensor)
+
+        # Output info
+        self.output_tensors = []
+        for i in range(n_output):
+            attr = runtime.get_tensor_attr(i, is_output=True)
+            shape = runtime.get_output_shape(i)
+            # Get dtype
+            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+            tensor = IOTensor(attr.name, shape, dtype)
+            self.output_tensors.append(tensor)
+
+    def get_inputs(self):
+        """
+        Get the model input info
+
+        Returns:
+            list: input info
+        """
+        return self.input_tensors
+
+    def get_outputs(self):
+        """
+        Get the model output info
+
+        Returns:
+            list: output info
+        """
+        return self.output_tensors
+
+    def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
+        """
+        Run inference
+
+        Args:
+            output_names: output names; selects which outputs to return
+            input_feed: input data as a dict or list
+            data_format: input data format, "nchw" or "nhwc"
+            **kwargs: other runtime parameters
+
+        Returns:
+            list: model outputs; only the named outputs if output_names is given
+        """
+        if input_feed is None:
+            logger.error("input_feed must not be None")
+            raise ValueError("input_feed must not be None")
+
+        # Prepare inputs
+        if isinstance(input_feed, dict):
+            # For a dict, order the inputs as the model expects
+            inputs = []
+            input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)}
+            for tensor in self.input_tensors:
+                if tensor.name not in input_feed:
+                    raise ValueError(f"Missing input: {tensor.name}")
+                inputs.append(input_feed[tensor.name])
+        elif isinstance(input_feed, (list, tuple)):
+            # For a list, check that the length matches
+            if len(input_feed) != len(self.input_tensors):
+                raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
+            inputs = list(input_feed)
+        else:
+            logger.error("input_feed must be a dict or a list")
+            raise ValueError("input_feed must be a dict or a list")
+
+        # Run inference
+        try:
+            logger.debug("Running inference")
+            all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
+
+            # Return all outputs if output_names is not given
+            if output_names is None:
+                return all_outputs
+
+            # Select the requested outputs
+            output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
+            selected_outputs = []
+            for name in output_names:
+                if name not in output_map:
+                    raise ValueError(f"Output node not found: {name}")
+                selected_outputs.append(all_outputs[output_map[name]])
+
+            return selected_outputs
+
+        except Exception as e:
+            logger.error(f"Inference failed: {str(e)}")
+            raise RuntimeError(f"Inference failed: {str(e)}")
+
+    def close(self):
+        """
+        Close the session and release resources
+        """
+        if self.runtime is not None:
+            logger.info("Releasing runtime resources")
+            self.runtime.release()
+            self.runtime = None
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def end_profiling(self) -> Optional[str]:
+        """
+        Stub for ending profiling
+
+        Returns:
+            Optional[str]: None
+        """
+        warnings.warn("end_profiling() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return None
+
+    def get_profiling_start_time_ns(self) -> int:
+        """
+        Stub for getting the profiling start time
+
+        Returns:
+            int: 0
+        """
+        warnings.warn("get_profiling_start_time_ns() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return 0
+
+    def get_modelmeta(self) -> Dict[str, str]:
+        """
+        Stub for getting model metadata
+
+        Returns:
+            Dict[str, str]: empty dict
+        """
+        warnings.warn("get_modelmeta() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return {}
+
+    def get_session_options(self) -> SessionOptions:
+        """
+        Get the session options
+
+        Returns:
+            SessionOptions: the current session options
+        """
+        return self.options
+
+    def get_providers(self) -> List[str]:
+        """
+        Stub for getting the active providers
+
+        Returns:
+            List[str]: ["CPUExecutionProvider"]
+        """
+        warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
+        return ["CPUExecutionProvider"]
+
+    def get_provider_options(self) -> Dict[str, Dict[str, str]]:
+        """
+        Stub for getting provider options
+
+        Returns:
+            Dict[str, Dict[str, str]]: empty dict
+        """
+        warnings.warn("get_provider_options() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return {}
+
+    def get_session_config(self) -> Dict[str, str]:
+        """
+        Stub for getting the session config
+
+        Returns:
+            Dict[str, str]: empty dict
+        """
+        warnings.warn("get_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return {}
+
+    def get_session_state(self) -> Dict[str, str]:
+        """
+        Stub for getting the session state
+
+        Returns:
+            Dict[str, str]: empty dict
+        """
+        warnings.warn("get_session_state() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return {}
+
+    def set_session_config(self, config: Dict[str, str]) -> None:
+        """
+        Stub for setting the session config
+
+        Args:
+            config: session config dict
+        """
+        warnings.warn("set_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+    def get_memory_info(self) -> Dict[str, int]:
+        """
+        Stub for getting memory usage info
+
+        Returns:
+            Dict[str, int]: empty dict
+        """
+        warnings.warn("get_memory_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return {}
+
+    def set_memory_pattern(self, enable: bool) -> None:
+        """
+        Stub for setting the memory pattern
+
+        Args:
+            enable: whether to enable the memory pattern
+        """
+        warnings.warn("set_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+    def disable_memory_pattern(self) -> None:
+        """
+        Stub for disabling the memory pattern
+        """
+        warnings.warn("disable_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+    def get_optimization_level(self) -> int:
+        """
+        Stub for getting the optimization level
+
+        Returns:
+            int: 0
+        """
+        warnings.warn("get_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return 0
+
+    def set_optimization_level(self, level: int) -> None:
+        """
+        Stub for setting the optimization level
+
+        Args:
+            level: optimization level
+        """
+        warnings.warn("set_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+    def get_model_metadata(self) -> Dict[str, str]:
+        """
+        Stub for getting model metadata (a separate interface from get_modelmeta)
+
+        Returns:
+            Dict[str, str]: empty dict
+        """
+        warnings.warn("get_model_metadata() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return {}
+
+    def get_model_path(self) -> str:
+        """
+        Get the model path
+
+        Returns:
+            str: model file path
+        """
+        return self.model_path
+
+    def get_input_type_info(self) -> List[Dict[str, str]]:
+        """
+        Stub for getting input type info
+
+        Returns:
+            List[Dict[str, str]]: empty list
+        """
+        warnings.warn("get_input_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return []
+
+    def get_output_type_info(self) -> List[Dict[str, str]]:
+        """
+        Stub for getting output type info
+
+        Returns:
+            List[Dict[str, str]]: empty list
+        """
+        warnings.warn("get_output_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+        return []