happyme531 committed
Commit 0053ecb · verified · 1 Parent(s): 26eaabd

Upload 10 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ea50ffd6-c6fe-11ef-8ff3-1c860b30973e filter=lfs diff=lfs merge=lfs -text
+ RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.emb filter=lfs diff=lfs merge=lfs -text
+ RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn filter=lfs diff=lfs merge=lfs -text
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.emb ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7bb71268884738ee0bbc62796b838afd9b460da931589151d949e538cbe58255
+ size 201326592
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:080c9153102fe9c2c54e8245411a9ab70360132a13321c3396dd7cca17eca1c4
+ size 305312
RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:17c375a232e19992bba49459fa7a092ecdb6252841b80850095eb5c6fb4e2bf4
+ size 289121271
convert_rknn.py ADDED
@@ -0,0 +1,54 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ import datetime
+ from rknn.api import RKNN
+ from sys import exit
+
+
+ ONNX_MODEL = "RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx"
+ RKNN_MODEL = ONNX_MODEL.replace(".onnx", ".rknn")
+ DATASET = ""
+ QUANTIZE = False
+ detailed_performance_log = True
+
+ timedate_iso = datetime.datetime.now().isoformat()
+
+ rknn = RKNN(verbose=True)
+ rknn.config(
+     # mean_values=[x * 255 for x in [0.485, 0.456, 0.406]],
+     # std_values=[x * 255 for x in [0.229, 0.224, 0.225]],
+     quantized_dtype="w8a8",
+     quantized_algorithm="normal",
+     quantized_method="channel",
+     quantized_hybrid_level=0,
+     target_platform="rk3588",
+     quant_img_RGB2BGR=False,
+     float_dtype="float16",
+     optimization_level=3,
+     custom_string=f"converted at {timedate_iso}",
+     remove_weight=False,
+     compress_weight=False,
+     inputs_yuv_fmt=None,
+     single_core_mode=False,
+     dynamic_input=None,
+     model_pruning=False,
+     op_target=None,
+     quantize_weight=False,
+     remove_reshape=False,
+     sparse_infer=False,
+     enable_flash_attention=False,
+     # hidden (undocumented) parameters
+     # disable_rules=[],
+     # sram_prefer=False,
+     # nbuf_prefer=False,
+     # check_data=[],
+ )
+
+ ret = rknn.load_onnx(model=ONNX_MODEL)
+ ret = rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None)
+ ret = rknn.export_rknn(RKNN_MODEL)
+
+ # ret = rknn.init_runtime(target='rk3588', device_id='cbb956772bf5dac9', core_mask=RKNN.NPU_CORE_0, perf_debug=detailed_performance_log)
+ # rknn.eval_perf()
+ # ret = rknn.accuracy_analysis(inputs=['../embeddings.npy','../state.npy','../scale_ratio.npy'], target='rk3588', device_id=device_id)
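
Note: the script above discards the return codes from `load_onnx`, `build`, and `export_rknn`. RKNN Toolkit APIs report failure through a non-zero return value rather than an exception (the rknnlite wrapper later in this commit checks `ret != 0` the same way), so a checked variant of those three calls could look like this sketch:

    if rknn.load_onnx(model=ONNX_MODEL) != 0:
        raise RuntimeError("load_onnx failed")
    if rknn.build(do_quantization=QUANTIZE, dataset=DATASET, rknn_batch_size=None) != 0:
        raise RuntimeError("build failed")
    if rknn.export_rknn(RKNN_MODEL) != 0:
        raise RuntimeError("export_rknn failed")
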
ea50ffd6-c6fe-11ef-8ff3-1c860b30973e ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:faa4dce148b8ed0172ef021b8f732c2eea5dd782caed801dd4727d909d2b9447
+ size 562805760
export_onnx.py ADDED
@@ -0,0 +1,120 @@
+ from rwkv_src.rwkv_model import RWKV_RNN, make_chunks
+ import types
+ import os
+ import torch
+ import numpy as np
+ import argparse
+ import json
+ import copy
+ from pathlib import Path
+ import onnx
+ from onnx import shape_inference
+
+ parser = argparse.ArgumentParser(description='Convert model')
+ parser.add_argument('model', type=Path, help='Path to RWKV pth file')
+ parser.add_argument('--chunks', type=int, default=1, help='Number of chunks')
+ parser.add_argument('--ext_embedding', action='store_true', default=False, help='Use external embedding')
+ parser.add_argument('--prefill_model', action='store_true', help='Convert model for sequential prefill')
+ parser.add_argument('--wkv_customop', action='store_true', help='Use custom op for wkv')
+ parser_args = parser.parse_args()
+
+ seq_length = 32 if parser_args.prefill_model else 1
+
+ model_args = types.SimpleNamespace()
+ model_args.USE_CUDA = False
+ model_args.fp16 = False
+ model_args.wkv_customop = parser_args.wkv_customop
+ model_args.USE_EMBEDDING = False if parser_args.ext_embedding else True
+
+ model_args.MODEL_NAME = str(parser_args.model)
+
+ if 'ABC' in model_args.MODEL_NAME or 'MIDI' in model_args.MODEL_NAME or 'x070' in model_args.MODEL_NAME:
+     model_args.RESCALE_LAYER = 0
+ else:
+     model_args.RESCALE_LAYER = 6
+
+ model = make_chunks(parser_args.chunks, model_args) if parser_args.chunks > 1 else RWKV_RNN(model_args)
+
+ if parser_args.prefill_model:
+     model_args.MODEL_NAME = model_args.MODEL_NAME + "_prefill"
+
+ os.path.exists("onnx") or os.mkdir("onnx")
+
+ if type(model) == list:
+     args = model[0].args
+     if not args.USE_EMBEDDING:
+         model[0].emb_weight.cpu().numpy().astype(np.float32).tofile("onnx/" + args.MODEL_NAME.split("/")[-1] + f"_chunk1of{len(model)}.emb")
+     args = model[0].args
+     fp16 = args.fp16
+     states = []
+     for i in range(args.n_layer):
+         states.append(torch.zeros(1, args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
+         states.append(torch.zeros(args.n_head, args.head_size, args.head_size, dtype=torch.float16 if fp16 else torch.float32))
+         states.append(torch.zeros(1, args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
+     if model[0].device != torch.device('cpu'):
+         states = [i.to(model[0].device) for i in states]
+
+     for i in range(len(model)):
+         dirname = "onnx/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}"
+         os.path.exists(dirname) or os.mkdir(dirname)
+         if i == 0 and args.USE_EMBEDDING:
+             in0 = torch.LongTensor([[1]*seq_length])
+         else:
+             in0 = torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
+
+         if model[0].device != torch.device('cpu'):
+             in0 = in0.to(model[0].device)
+         inputs = {'in0': in0, 'state': [states[j] for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]}
+         input_names = ['in'] + [f'state{j}_in' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]
+         output_names = ['out'] + [f'state{j}_out' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]
+
+         if args.wkv_customop:
+             from torch.onnx.symbolic_helper import _get_tensor_sizes
+             from torch.onnx import register_custom_op_symbolic
+             op_name = "rwkv::wkv_chunk" if parser_args.prefill_model else "rwkv::wkv"
+             def onnx_custom_wkv(g, k, v, r, state2, time_first, time_decay):
+                 out1, out2 = g.op(op_name, k, v, r, state2, time_first, time_decay, outputs=2)
+                 return out1.setType(k.type().with_dtype(torch.float32).with_sizes([seq_length, _get_tensor_sizes(k)[0], 1, args.head_size])),\
+                     out2.setType(k.type().with_dtype(torch.float32).with_sizes([1, _get_tensor_sizes(k)[0], args.head_size, args.head_size]))
+             register_custom_op_symbolic(op_name, onnx_custom_wkv, 9)
+
+         torch.onnx.export(model[i], inputs, dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx", input_names=input_names, output_names=output_names, opset_version=17)
+         shape_inference.infer_shapes_path(dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx")
+         onnx_model = onnx.load(dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx")
+
+         # To make the model compatible with other frameworks
+         for initializer in onnx_model.graph.initializer:
+             shape = list(initializer.dims)
+             value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
+             onnx_model.graph.value_info.append(value_info)
+         onnx.save_model(onnx_model, dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx", save_as_external_data=True, all_tensors_to_one_file=True)
+         print(f"onnx model chunk{i} saved to {dirname}" + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx")
+
+ else:
+     args = model.args
+     if not args.USE_EMBEDDING:
+         model.emb_weight.cpu().numpy().astype(np.float32).tofile("onnx/" + args.MODEL_NAME.split("/")[-1] + ".emb")
+     args = model.args
+     fp16 = args.fp16
+     in0 = torch.LongTensor([[1]*seq_length]) if args.USE_EMBEDDING else torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
+     states = []
+     for i in range(model.args.n_layer):
+         states.append(torch.zeros(1, model.args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
+         states.append(torch.zeros(model.args.n_head, model.args.head_size, model.args.head_size, dtype=torch.float16 if fp16 else torch.float32))
+         states.append(torch.zeros(1, model.args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
+     if model.device != torch.device('cpu'):
+         states = [tensor.to(model.device) for tensor in states]
+     inputs = {'in0': in0, 'state': states}
+     input_names = ['in'] + [f'state{i}_in' for i in range(3*model.args.n_layer)]
+     output_names = ['logits'] + [f'state{i}_out' for i in range(3*model.args.n_layer)]
+     torch.onnx.export(model, inputs, "onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx", input_names=input_names, output_names=output_names, opset_version=17)
+     shape_inference.infer_shapes_path("onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx")
+     onnx_model = onnx.load("onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx")
+
+     # To make the model compatible with other frameworks
+     for initializer in onnx_model.graph.initializer:
+         shape = list(initializer.dims)
+         value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
+         onnx_model.graph.value_info.append(value_info)
+     onnx.save_model(onnx_model, "onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx", save_as_external_data=True, all_tensors_to_one_file=True)
+     print("onnx model saved to onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx")
inference.py ADDED
@@ -0,0 +1,210 @@
+ # import onnxruntime as ort  # Uncomment this line to use onnxruntime
+ import ztu_somemodelruntime_rknnlite2 as ort  # Comment this line out when using onnxruntime
+ import numpy as np
+ from pathlib import Path
+ from rwkv_tokenizer import RWKV_TOKENIZER
+ import time
+
+ class RWKVModel:
+     def __init__(self, model_path: str, tokenizer_path: str = None, use_external_embedding: bool = False):
+         # Load the ONNX model
+         session_options = ort.SessionOptions()
+         # session_options.core_mask = 7  # 0b00000111: use cores 0, 1 and 2
+         self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'], session_options=session_options)
+
+         # Print model input info
+         print("\nModel inputs:")
+         for inp in self.session.get_inputs():
+             print(f"{inp.name}: shape={inp.shape}, type={inp.type}")
+
+         # Read model metadata
+         self.n_layer = len([x for x in self.session.get_inputs() if 'state' in x.name]) // 3
+         self.n_embd = self.session.get_inputs()[0].shape[-1] if not use_external_embedding else None
+
+         # Read the state tensor shapes from the model
+         self.state_shapes = {}
+         for inp in self.session.get_inputs():
+             if 'state' in inp.name:
+                 self.state_shapes[inp.name] = inp.shape
+
+         print("\nNumber of layers:", self.n_layer)
+
+         # Load the tokenizer
+         if tokenizer_path:
+             self.tokenizer = RWKV_TOKENIZER(tokenizer_path)
+         else:
+             self.tokenizer = None
+
+         # Load the external embedding (if requested)
+         self.use_external_embedding = use_external_embedding
+         if use_external_embedding:
+             emb_path = Path(model_path).parent / (Path(model_path).stem + '.emb')
+             self.embedding = np.fromfile(emb_path, dtype=np.float32)
+             # Reshape the flat embedding array
+             vocab_size = len(self.embedding) // 768  # assumes an embedding width of 768
+             self.embedding = self.embedding.reshape(vocab_size, 768)
+             self.n_embd = 768
+             print(f"\nEmbedding shape: {self.embedding.shape}")
+
+         # Initialize the states
+         self.reset_state()
+
+     def reset_state(self):
+         """Reset all states to zero"""
+         self.states = []
+         for i in range(self.n_layer * 3):
+             state_name = f'state{i}_in'
+             state_shape = self.state_shapes[state_name]
+             self.states.append(np.zeros(state_shape, dtype=np.float32))
+
+     def _prepare_inputs(self, token_id):
+         """Prepare the model inputs"""
+         inputs = {}
+
+         # Main input
+         if self.use_external_embedding:
+             # Look up the external embedding
+             embedding = self.embedding[token_id].reshape(1, 1, self.n_embd)
+             inputs['in'] = embedding.astype(np.float32)
+         else:
+             # Feed the raw token id
+             inputs['in'] = np.array([[token_id]], dtype=np.int64)
+
+         # Attach the states
+         for i in range(len(self.states)):
+             inputs[f'state{i}_in'] = self.states[i]
+
+         # Print input shapes
+         if token_id == 0:  # only print once, when the token id is 0
+             print("\nPrepared input shapes:")
+             for k, v in inputs.items():
+                 print(f"{k}: shape={v.shape}, type={v.dtype}")
+
+         return inputs
+
+     def forward(self, token_id):
+         """Single-step inference"""
+         # Prepare the inputs
+         inputs = self._prepare_inputs(token_id)
+
+         # Run inference
+         outputs = self.session.run(None, inputs)
+
+         # Print output info (first time only)
+         if token_id == 0:
+             print("\nModel outputs:")
+             for i, out in enumerate(outputs):
+                 print(f"Output {i}: shape={out.shape}, type={out.dtype}")
+
+         # Update the states
+         for i in range(len(self.states)):
+             new_state = outputs[i + 1]  # the first output is the logits
+             # Make sure the dimensions match
+             if new_state.shape != self.states[i].shape:
+                 if token_id == 0:
+                     print(f"\nState shape mismatch for state{i}_in:")
+                     print(f"Expected: {self.states[i].shape}")
+                     print(f"Got: {new_state.shape}")
+                 # Squeeze away the extra dimension
+                 if len(self.states[i].shape) == 2:    # (1, 768)
+                     new_state = new_state.squeeze(1)  # (1, 1, 768) -> (1, 768)
+                 elif len(self.states[i].shape) == 3:  # (12, 64, 64)
+                     new_state = new_state.squeeze(0)  # (1, 12, 64, 64) -> (12, 64, 64)
+             self.states[i] = new_state
+
+         return outputs[0]  # return the logits
+
+     def generate(self, prompt: str, max_length: int = 100, temperature: float = 1.0, stop_tokens: set = None):
+         """Generate text"""
+         if not self.tokenizer:
+             raise ValueError("A tokenizer is required for text generation")
+
+         # Encode the prompt
+         tokens = self.tokenizer.encode(prompt)
+         generated = list(tokens)
+
+         # Reset the states
+         self.reset_state()
+
+         # Process the prompt
+         print("\nProcessing prompt...", end='', flush=True)
+         t_start = time.time()
+         for token in tokens:
+             logits = self.forward(token)
+         t_prompt = time.time() - t_start
+         print(f" Done. ({len(tokens)} tokens, {t_prompt:.2f}s, {len(tokens)/t_prompt:.2f} tokens/s)")
+
+         # Generate new tokens
+         print("\nGenerating:", end='', flush=True)
+         t_start = time.time()
+         generated_tokens = 0
+
+         for i in range(max_length):
+             # Get the logits and apply temperature
+             t_token_start = time.time()
+             logits = self.forward(generated[-1])
+
+             # Print the logits info for the first generated token
+             if i == 0:
+                 print(f"\nLogits shape: {logits.shape}")
+
+             # Flatten the logits to one dimension
+             logits = logits.reshape(-1)
+
+             if temperature > 0:
+                 # Apply temperature and compute probabilities
+                 logits = logits / temperature
+                 # Subtract the max to avoid overflow in exp
+                 logits = logits - np.max(logits)
+                 probs = np.exp(logits)
+                 probs = probs / np.sum(probs)
+                 next_token = np.random.choice(len(probs), p=probs)
+             else:
+                 next_token = np.argmax(logits)
+
+             generated.append(next_token)
+             generated_tokens += 1
+
+             # Stop if a stop token was generated
+             if stop_tokens and next_token in stop_tokens:
+                 break
+
+             # Stream the newly generated token
+             new_text = self.tokenizer.decode([next_token])
+             print(new_text, end='', flush=True)
+
+         t_generate = time.time() - t_start
+         print(f"\n\nGeneration finished: {generated_tokens} tokens generated in {t_generate:.2f}s ({generated_tokens/t_generate:.2f} tokens/s)")
+
+         return self.tokenizer.decode(generated)
+
+ def main():
+     import time
+
+     # Usage example
+     print("Loading model...")
+     t_start = time.time()
+     model = RWKVModel(
+         model_path='RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx',
+         tokenizer_path='rwkv_vocab_v20230424.txt',
+         use_external_embedding=True
+     )
+     print(f"Model loaded in {time.time() - t_start:.2f}s")
+
+     prompt = "Here is an example of Quick Sort algorithm implemented in C++:\n```cpp"
+     print(f"\nPrompt: {prompt}")
+
+     generated_text = model.generate(
+         prompt=prompt,
+         max_length=1024,
+         temperature=0.7,
+         stop_tokens={0, 1, 2, 3}  # special tokens used as stop markers
+     )
+
+     print("\nFull text:")
+     print(generated_text)
+
+ if __name__ == '__main__':
+     main()
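
Note: `RWKVModel` hard-codes an external-embedding width of 768, which is consistent with the `.emb` file uploaded in this commit. A quick arithmetic check, using the size from the LFS pointer above:

    # Sanity check for the hard-coded n_embd = 768 (float32 external embedding):
    emb_bytes = 201326592               # size of RWKV-x070-...-ctx4096.emb per its LFS pointer
    vocab_size = emb_bytes // 4 // 768  # 4 bytes per float32, 768 dims per row
    assert vocab_size == 65536          # matches the RWKV "World" vocabulary size
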
rwkv_tokenizer.py ADDED
@@ -0,0 +1,89 @@
+ from typing import List,Set,Dict
+
+ class ABCTokenizer():
+     def __init__(self):
+         self.pad_token_id = 0
+         self.bos_token_id = 2
+         self.eos_token_id = 3
+     def encode(self, text):
+         ids = [ord(c) for c in text]
+         return ids
+     def decode(self, ids):
+         txt = ''.join(chr(idx) if idx > self.eos_token_id else '' for idx in ids if idx != self.eos_token_id)
+         return txt
+
+ class RWKV_TOKENIZER():
+     table: List[List[List[bytes]]]
+     good: List[Set[int]]
+     wlen: List[int]
+     def __init__(self, file_name):
+         self.idx2token = {}
+         sorted = []  # must be already sorted
+         lines = open(file_name, "r", encoding="utf-8").readlines()
+         for l in lines:
+             idx = int(l[:l.index(' ')])
+             x = eval(l[l.index(' '):l.rindex(' ')])
+             x = x.encode("utf-8") if isinstance(x, str) else x
+             assert isinstance(x, bytes)
+             assert len(x) == int(l[l.rindex(' '):])
+             sorted += [x]
+             self.idx2token[idx] = x
+
+         self.token2idx = {}
+         for k, v in self.idx2token.items():
+             self.token2idx[v] = int(k)
+
+         # precompute some tables for fast matching
+         self.table = [[[] for j in range(256)] for i in range(256)]
+         self.good = [set() for i in range(256)]
+         self.wlen = [0 for i in range(256)]
+
+         for i in reversed(range(len(sorted))):  # reverse order - match longer tokens first
+             s = sorted[i]
+             if len(s) >= 2:
+                 s0 = int(s[0])
+                 s1 = int(s[1])
+                 self.table[s0][s1] += [s]
+                 self.wlen[s0] = max(self.wlen[s0], len(s))
+                 self.good[s0].add(s1)
+
+     def encodeBytes(self, src: bytes) -> List[int]:
+         src_len: int = len(src)
+         tokens: List[int] = []
+         i: int = 0
+         while i < src_len:
+             s: bytes = src[i : i + 1]
+
+             if i < src_len - 1:
+                 s1: int = int(src[i + 1])
+                 s0: int = int(src[i])
+                 if s1 in self.good[s0]:
+                     sss: bytes = src[i : i + self.wlen[s0]]
+                     try:
+                         s = next(filter(sss.startswith, self.table[s0][s1]))
+                     except:
+                         pass
+             tokens.append(self.token2idx[s])
+             i += len(s)
+
+         return tokens
+
+     def decodeBytes(self, tokens):
+         return b''.join(map(lambda i: self.idx2token[i], tokens))
+
+     def encode(self, src: str):
+         return self.encodeBytes(src.encode("utf-8"))
+
+     def decode(self, tokens):
+         return self.decodeBytes(tokens).decode('utf-8')
+
+     def printTokens(self, tokens):
+         for i in tokens:
+             s = self.idx2token[i]
+             try:
+                 s = s.decode('utf-8')
+             except:
+                 pass
+             print(f'{repr(s)}{i}', end=' ')
+             # print(repr(s), i)
+         print()
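
Note: `RWKV_TOKENIZER` performs greedy longest-match over raw bytes; for each starting byte it precomputes the admissible second bytes (`good`), the candidate multi-byte tokens in longest-first order (`table`), and the maximum token length (`wlen`), falling back to the single-byte token when no longer match exists. A minimal usage sketch with the vocab file from this commit:

    tok = RWKV_TOKENIZER("rwkv_vocab_v20230424.txt")
    ids = tok.encode("Hello world")
    print(ids)
    assert tok.decode(ids) == "Hello world"  # byte-level round trip
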
rwkv_vocab_v20230424.txt ADDED
The diff for this file is too large to render. See raw diff
 
ztu_somemodelruntime_rknnlite2.py ADDED
@@ -0,0 +1,509 @@
+ # Module-level constants and functions
+ from rknnlite.api import RKNNLite
+ import numpy as np
+ import os
+ import warnings
+ import logging
+ from typing import List, Dict, Union, Optional
+
+ # Configure logging
+ logger = logging.getLogger("somemodelruntime_rknnlite2")
+ logger.setLevel(logging.ERROR)  # only log errors by default
+ if not logger.handlers:
+     handler = logging.StreamHandler()
+     handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
+     logger.addHandler(handler)
+
+ # Mapping from ONNX Runtime log levels to Python logging levels
+ _LOGGING_LEVEL_MAP = {
+     0: logging.DEBUG,     # Verbose
+     1: logging.INFO,      # Info
+     2: logging.WARNING,   # Warning
+     3: logging.ERROR,     # Error
+     4: logging.CRITICAL   # Fatal
+ }
+
+ def set_default_logger_severity(level: int) -> None:
+     """
+     Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal
+
+     Args:
+         level: logging level (0-4)
+     """
+     if level not in _LOGGING_LEVEL_MAP:
+         raise ValueError(f"Invalid logging level: {level}, expected an integer between 0 and 4")
+     logger.setLevel(_LOGGING_LEVEL_MAP[level])
+
+ def set_default_logger_verbosity(level: int) -> None:
+     """
+     Sets the default logging verbosity level. To activate the verbose log,
+     you need to set the default logging severity to 0:Verbose level.
+
+     Args:
+         level: logging level (0-4)
+     """
+     set_default_logger_severity(level)
+
+ # NPU core mode constants
+ NPU_CORE_AUTO = 0       # automatic selection
+ NPU_CORE_0 = 1          # use core 0
+ NPU_CORE_1 = 2          # use core 1
+ NPU_CORE_2 = 4          # use core 2
+ NPU_CORE_0_1 = 3        # use cores 0 and 1
+ NPU_CORE_0_1_2 = 7      # use cores 0, 1 and 2
+ NPU_CORE_ALL = 0xffff   # use all cores
+
+ # Mapping from RKNN tensor types to numpy dtypes
+ RKNN_DTYPE_MAP = {
+     0: np.float32,   # RKNN_TENSOR_FLOAT32
+     1: np.float16,   # RKNN_TENSOR_FLOAT16
+     2: np.int8,      # RKNN_TENSOR_INT8
+     3: np.uint8,     # RKNN_TENSOR_UINT8
+     4: np.int16,     # RKNN_TENSOR_INT16
+     5: np.uint16,    # RKNN_TENSOR_UINT16
+     6: np.int32,     # RKNN_TENSOR_INT32
+     7: np.uint32,    # RKNN_TENSOR_UINT32
+     8: np.int64,     # RKNN_TENSOR_INT64
+     9: bool,         # RKNN_TENSOR_BOOL
+     10: np.int8,     # RKNN_TENSOR_INT4 (represented as int8)
+ }
+
+ def get_available_providers() -> List[str]:
+     """
+     Get the list of available execution providers (placeholder kept for API compatibility)
+
+     Returns:
+         list: available providers, always ["CPUExecutionProvider"]
+     """
+     return ["CPUExecutionProvider"]
+
+ def get_version_info() -> Dict[str, str]:
+     """
+     Get version information
+
+     Returns:
+         dict: API and driver version info
+     """
+     runtime = RKNNLite()
+     version = runtime.get_sdk_version()
+     return {
+         "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
+         "driver_version": version.split('\n')[3].split(': ')[1]
+     }
+
+ class IOTensor:
+     """Wrapper around input/output tensor metadata"""
+     def __init__(self, name, shape, type=None):
+         self.name = name.decode() if isinstance(name, bytes) else name
+         self.shape = shape
+         self.type = type
+
+     def __str__(self):
+         return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
+
+ class SessionOptions:
+     """Session options"""
+     def __init__(self):
+         self.async_mode = False   # whether to run inference asynchronously
+         self.core_mask = 0        # NPU core selection
+         self.perf_debug = False   # whether to enable performance profiling
+
+ class InferenceSession:
+     """
+     RKNNLite runtime wrapper with an ONNX Runtime style API
+     """
+
+     def __init__(self, model_path: str, verbose: bool = False, session_options: Optional[SessionOptions] = None, **kwargs):
+         """
+         Initialize the runtime and load the model
+
+         Args:
+             model_path: path to the model file (.rknn or .onnx)
+             verbose: whether to print verbose logs
+             session_options: session options
+             **kwargs: other initialization parameters
+         """
+         # Enable verbose logging only when verbose=True
+         if verbose:
+             set_default_logger_severity(0)  # Verbose
+
+         self.model_path = self._process_model_path(model_path)
+         self.runtime = RKNNLite(verbose=verbose)
+
+         # Load the model
+         logger.debug(f"Loading model: {self.model_path}")
+         ret = self.runtime.load_rknn(self.model_path)
+         if ret != 0:
+             logger.error(f"Failed to load RKNN model: {self.model_path}")
+             raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
+         logger.debug("Model loaded successfully")
+
+         # Apply the session options
+         options = session_options or SessionOptions()
+
+         # Initialize the runtime
+         logger.debug("Initializing runtime environment")
+         ret = self.runtime.init_runtime(
+             async_mode=options.async_mode,
+             core_mask=options.core_mask
+         )
+         if ret != 0:
+             logger.error("Failed to initialize runtime environment")
+             raise RuntimeError('Failed to initialize runtime environment')
+         logger.debug("Runtime environment initialized successfully")
+
+         # Collect input/output metadata
+         self._init_io_info()
+
+         # Keep the options around
+         self.options = options
+
+     def get_performance_info(self) -> Dict[str, float]:
+         """
+         Get performance information
+
+         Returns:
+             dict: performance data
+         """
+         if not self.options.perf_debug:
+             raise RuntimeError("Profiling is not enabled; set perf_debug=True in SessionOptions")
+
+         perf = self.runtime.rknn_runtime.get_run_perf()
+         return {
+             "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
+         }
+
+     def set_core_mask(self, core_mask: int) -> None:
+         """
+         Set the NPU core usage mode
+
+         Args:
+             core_mask: NPU core mask; use the NPU_CORE_* constants
+         """
+         ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
+         if ret != 0:
+             raise RuntimeError("Failed to set NPU core mode")
+
+     def _process_model_path(self, model_path):
+         """Resolve the model path; supports .onnx and .rknn files"""
+         if not os.path.exists(model_path):
+             logger.error(f"Model file does not exist: {model_path}")
+             raise FileNotFoundError(f"Model file does not exist: {model_path}")
+
+         # If an ONNX file was given
+         if model_path.lower().endswith('.onnx'):
+             logger.warning(
+                 "Detected an ONNX model file. Note: SomeModelRuntime does not convert ONNX to RKNN automatically. "
+                 "Convert the model with RKNN Toolkit first. "
+                 "Trying to load the .rknn file with the same name instead."
+             )
+             # Construct the RKNN file path
+             rknn_path = os.path.splitext(model_path)[0] + '.rknn'
+             if not os.path.exists(rknn_path):
+                 logger.error(f"RKNN model file does not exist: {rknn_path}")
+                 raise FileNotFoundError(
+                     f"RKNN model file does not exist: {rknn_path}\n"
+                     "Convert the ONNX model to RKNN format with RKNN Toolkit first."
+                 )
+             return rknn_path
+
+         return model_path
+
+     def _convert_nhwc_to_nchw(self, shape):
+         """Convert an NHWC shape to NCHW"""
+         if len(shape) == 4:
+             # NHWC -> NCHW
+             n, h, w, c = shape
+             return [n, c, h, w]
+         return shape
+
+     def _init_io_info(self):
+         """Collect the model's input/output metadata"""
+         runtime = self.runtime.rknn_runtime
+
+         # Number of inputs and outputs
+         n_input, n_output = runtime.get_in_out_num()
+
+         # Input info
+         self.input_tensors = []
+         for i in range(n_input):
+             attr = runtime.get_tensor_attr(i)
+             shape = [attr.dims[j] for j in range(attr.n_dims)]
+             # Convert 4-D inputs from NHWC to NCHW
+             shape = self._convert_nhwc_to_nchw(shape)
+             # Resolve the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.input_tensors.append(tensor)
+
+         # Output info
+         self.output_tensors = []
+         for i in range(n_output):
+             attr = runtime.get_tensor_attr(i, is_output=True)
+             shape = runtime.get_output_shape(i)
+             # Resolve the dtype
+             dtype = RKNN_DTYPE_MAP.get(attr.type, None)
+             tensor = IOTensor(attr.name, shape, dtype)
+             self.output_tensors.append(tensor)
+
+     def get_inputs(self):
+         """
+         Get the model input info
+
+         Returns:
+             list: input descriptions
+         """
+         return self.input_tensors
+
+     def get_outputs(self):
+         """
+         Get the model output info
+
+         Returns:
+             list: output descriptions
+         """
+         return self.output_tensors
+
+     def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
+         """
+         Run inference
+
+         Args:
+             output_names: names of the output nodes to return
+             input_feed: input data as a dict or list
+             data_format: input data format, "nchw" or "nhwc"
+             **kwargs: other runtime parameters
+
+         Returns:
+             list: model outputs; only the outputs named in output_names if it is given
+         """
+         if input_feed is None:
+             logger.error("input_feed must not be None")
+             raise ValueError("input_feed must not be None")
+
+         # Prepare the input data
+         if isinstance(input_feed, dict):
+             # For a dict, order the inputs to match the model
+             inputs = []
+             input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)}
+             for tensor in self.input_tensors:
+                 if tensor.name not in input_feed:
+                     raise ValueError(f"Missing input: {tensor.name}")
+                 inputs.append(input_feed[tensor.name])
+         elif isinstance(input_feed, (list, tuple)):
+             # For a list, make sure the length matches
+             if len(input_feed) != len(self.input_tensors):
+                 raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
+             inputs = list(input_feed)
+         else:
+             logger.error("input_feed must be a dict or a list")
+             raise ValueError("input_feed must be a dict or a list")
+
+         # Run inference
+         try:
+             logger.debug("Starting inference")
+             all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
+
+             # If no output_names were given, return all outputs
+             if output_names is None:
+                 return all_outputs
+
+             # Pick the requested outputs
+             output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
+             selected_outputs = []
+             for name in output_names:
+                 if name not in output_map:
+                     raise ValueError(f"Output node not found: {name}")
+                 selected_outputs.append(all_outputs[output_map[name]])
+
+             return selected_outputs
+
+         except Exception as e:
+             logger.error(f"Inference failed: {str(e)}")
+             raise RuntimeError(f"Inference failed: {str(e)}")
+
+     def close(self):
+         """
+         Close the session and release resources
+         """
+         if self.runtime is not None:
+             logger.info("Releasing runtime resources")
+             self.runtime.release()
+             self.runtime = None
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self.close()
+
+     def end_profiling(self) -> Optional[str]:
+         """
+         Stub for ending profiling
+
+         Returns:
+             Optional[str]: None
+         """
+         warnings.warn("end_profiling() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return None
+
+     def get_profiling_start_time_ns(self) -> int:
+         """
+         Stub for getting the profiling start time
+
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_profiling_start_time_ns() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return 0
+
+     def get_modelmeta(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_modelmeta() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_options(self) -> SessionOptions:
+         """
+         Get the session options
+
+         Returns:
+             SessionOptions: current session options
+         """
+         return self.options
+
+     def get_providers(self) -> List[str]:
+         """
+         Stub for getting the active providers
+
+         Returns:
+             List[str]: ["CPUExecutionProvider"]
+         """
+         warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
+         return ["CPUExecutionProvider"]
+
+     def get_provider_options(self) -> Dict[str, Dict[str, str]]:
+         """
+         Stub for getting provider options
+
+         Returns:
+             Dict[str, Dict[str, str]]: empty dict
+         """
+         warnings.warn("get_provider_options() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_config(self) -> Dict[str, str]:
+         """
+         Stub for getting the session config
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_session_state(self) -> Dict[str, str]:
+         """
+         Stub for getting the session state
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_session_state() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def set_session_config(self, config: Dict[str, str]) -> None:
+         """
+         Stub for setting the session config
+
+         Args:
+             config: session config dict
+         """
+         warnings.warn("set_session_config() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_memory_info(self) -> Dict[str, int]:
+         """
+         Stub for getting memory usage info
+
+         Returns:
+             Dict[str, int]: empty dict
+         """
+         warnings.warn("get_memory_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def set_memory_pattern(self, enable: bool) -> None:
+         """
+         Stub for setting the memory pattern
+
+         Args:
+             enable: whether to enable the memory pattern
+         """
+         warnings.warn("set_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def disable_memory_pattern(self) -> None:
+         """
+         Stub for disabling the memory pattern
+         """
+         warnings.warn("disable_memory_pattern() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_optimization_level(self) -> int:
+         """
+         Stub for getting the optimization level
+
+         Returns:
+             int: 0
+         """
+         warnings.warn("get_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return 0
+
+     def set_optimization_level(self, level: int) -> None:
+         """
+         Stub for setting the optimization level
+
+         Args:
+             level: optimization level
+         """
+         warnings.warn("set_optimization_level() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+
+     def get_model_metadata(self) -> Dict[str, str]:
+         """
+         Stub for getting model metadata (separate interface from get_modelmeta)
+
+         Returns:
+             Dict[str, str]: empty dict
+         """
+         warnings.warn("get_model_metadata() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return {}
+
+     def get_model_path(self) -> str:
+         """
+         Get the model path
+
+         Returns:
+             str: path to the model file
+         """
+         return self.model_path
+
+     def get_input_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting input type info
+
+         Returns:
+             List[Dict[str, str]]: empty list
+         """
+         warnings.warn("get_input_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return []
+
+     def get_output_type_info(self) -> List[Dict[str, str]]:
+         """
+         Stub for getting output type info
+
+         Returns:
+             List[Dict[str, str]]: empty list
+         """
+         warnings.warn("get_output_type_info() is a stub and does nothing", RuntimeWarning, stacklevel=2)
+         return []
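
Note: the module above mimics the onnxruntime `InferenceSession` API on top of RKNNLite, which is how inference.py can swap backends by changing one import. A minimal drop-in usage sketch (the `.rknn` path and the float32 input are assumptions):

    import numpy as np
    import ztu_somemodelruntime_rknnlite2 as ort

    opts = ort.SessionOptions()
    opts.core_mask = ort.NPU_CORE_0_1_2       # pin to all three RK3588 NPU cores
    with ort.InferenceSession("model.rknn", session_options=opts) as sess:
        inp = sess.get_inputs()[0]
        print(inp)                            # IOTensor(name=..., shape=..., type=...)
        x = np.zeros(inp.shape, dtype=np.float32)
        outputs = sess.run(None, {inp.name: x})
        print([o.shape for o in outputs])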