# import onnxruntime as ort  # Uncomment this line to use onnxruntime
import ztu_somemodelruntime_rknnlite2 as ort  # Comment this line out when using onnxruntime
import numpy as np
from pathlib import Path
from rwkv_tokenizer import RWKV_TOKENIZER
import time


class RWKVModel:
    def __init__(self, model_path: str, tokenizer_path: str = None, use_external_embedding: bool = False):
        # Load the ONNX model
        session_options = ort.SessionOptions()
        # session_options.core_mask = 7  # 0b00000111: use cores 0, 1, and 2
        self.session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'],
                                            session_options=session_options)

        # Print model input info
        print("\nModel inputs:")
        for inp in self.session.get_inputs():
            print(f"{inp.name}: shape={inp.shape}, type={inp.type}")

        # Derive model dimensions from the inputs: each layer has three state tensors
        self.n_layer = len([x for x in self.session.get_inputs() if 'state' in x.name]) // 3
        self.n_embd = self.session.get_inputs()[0].shape[-1] if not use_external_embedding else None

        # Record the shape of each state tensor as declared by the model
        self.state_shapes = {}
        for inp in self.session.get_inputs():
            if 'state' in inp.name:
                self.state_shapes[inp.name] = inp.shape

        print("\nNumber of layers:", self.n_layer)

        # Load the tokenizer
        if tokenizer_path:
            self.tokenizer = RWKV_TOKENIZER(tokenizer_path)
        else:
            self.tokenizer = None

        # Load the external embedding table if requested
        self.use_external_embedding = use_external_embedding
        if use_external_embedding:
            emb_path = Path(model_path).parent / (Path(model_path).stem + '.emb')
            self.embedding = np.fromfile(emb_path, dtype=np.float32)
            # Reshape the flat array into (vocab_size, n_embd);
            # the embedding dimension is assumed to be 768
            vocab_size = len(self.embedding) // 768
            self.embedding = self.embedding.reshape(vocab_size, 768)
            self.n_embd = 768
            print(f"\nEmbedding shape: {self.embedding.shape}")

        # Initialize the recurrent state
        self.reset_state()

    def reset_state(self):
        """Reset all state tensors to zero."""
        self.states = []
        for i in range(self.n_layer * 3):
            state_shape = self.state_shapes[f'state{i}_in']
            self.states.append(np.zeros(state_shape, dtype=np.float32))

    def _prepare_inputs(self, token_id):
        """Assemble the input dict for one inference step."""
        inputs = {}

        # Main input: either a precomputed embedding row or the raw token id
        if self.use_external_embedding:
            embedding = self.embedding[token_id].reshape(1, 1, self.n_embd)
            inputs['in'] = embedding.astype(np.float32)
        else:
            inputs['in'] = np.array([[token_id]], dtype=np.int64)

        # Attach the current state tensors
        for i in range(len(self.states)):
            inputs[f'state{i}_in'] = self.states[i]

        # Debug print of input shapes (only when token id 0 is seen)
        if token_id == 0:
            print("\nPrepared input shapes:")
            for k, v in inputs.items():
                print(f"{k}: shape={v.shape}, type={v.dtype}")

        return inputs

    def forward(self, token_id):
        """Run one inference step and update the recurrent state."""
        inputs = self._prepare_inputs(token_id)
        outputs = self.session.run(None, inputs)

        # Print output info (only when token id 0 is seen)
        if token_id == 0:
            print("\nModel outputs:")
            for i, out in enumerate(outputs):
                print(f"Output {i}: shape={out.shape}, type={out.dtype}")

        # Update states; the first output is the logits, the rest are states
        for i in range(len(self.states)):
            new_state = outputs[i + 1]
            # Reconcile any shape mismatch between declared and returned states
            if new_state.shape != self.states[i].shape:
                if token_id == 0:
                    print(f"\nState shape mismatch for state{i}_in:")
                    print(f"Expected: {self.states[i].shape}")
                    print(f"Got: {new_state.shape}")
                if len(self.states[i].shape) == 2:    # e.g. (1, 768)
                    new_state = new_state.squeeze(1)  # (1, 1, 768) -> (1, 768)
                elif len(self.states[i].shape) == 3:  # e.g. (12, 64, 64)
                    new_state = new_state.squeeze(0)  # (1, 12, 64, 64) -> (12, 64, 64)
            self.states[i] = new_state

        return outputs[0]  # logits

    def generate(self, prompt: str, max_length: int = 100, temperature: float = 1.0, stop_tokens: set = None):
        """Generate text from a prompt."""
        if not self.tokenizer:
            raise ValueError("A tokenizer is required for text generation")

        # Encode the prompt and reset the recurrent state
        tokens = self.tokenizer.encode(prompt)
        generated = list(tokens)
        self.reset_state()

        # Feed the prompt through the model token by token
        print("\nProcessing prompt...", end='', flush=True)
        t_start = time.time()
        for token in tokens:
            logits = self.forward(token)
        t_prompt = time.time() - t_start
        print(f" Done. ({len(tokens)} tokens, {t_prompt:.2f}s, {len(tokens)/t_prompt:.2f} tokens/s)")

        # Generate new tokens
        print("\nGenerating:", end='', flush=True)
        t_start = time.time()
        generated_tokens = 0
        for i in range(max_length):
            logits = self.forward(generated[-1])

            # Print logits info for the first generated token
            if i == 0:
                print(f"\nLogits shape: {logits.shape}")

            # Flatten the logits to 1-D
            logits = logits.reshape(-1)

            if temperature > 0:
                # Apply temperature, then softmax (subtract the max to avoid exp overflow)
                logits = logits / temperature
                logits = logits - np.max(logits)
                probs = np.exp(logits)
                probs = probs / np.sum(probs)
                next_token = np.random.choice(len(probs), p=probs)
            else:
                # Greedy decoding
                next_token = np.argmax(logits)

            generated.append(next_token)
            generated_tokens += 1

            # Stop if a stop token was generated
            if stop_tokens and next_token in stop_tokens:
                break

            # Stream the newly generated token
            new_text = self.tokenizer.decode([next_token])
            print(new_text, end='', flush=True)

        t_generate = time.time() - t_start
        print(f"\n\nGeneration finished: {generated_tokens} tokens generated in "
              f"{t_generate:.2f}s ({generated_tokens/t_generate:.2f} tokens/s)")

        return self.tokenizer.decode(generated)
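
# The helper below is a minimal sketch, not part of the original script: it adds
# nucleus (top-p) sampling as an alternative to the plain temperature sampling
# used in generate(). It assumes `logits` is the flat 1-D float vector obtained
# from RWKVModel.forward() after reshape(-1). To try it, replace the sampling
# branch in generate() with:
#     next_token = sample_top_p(logits, temperature, top_p=0.9)
def sample_top_p(logits: np.ndarray, temperature: float = 1.0, top_p: float = 0.9) -> int:
    """Sample a token id, keeping only the smallest set of tokens whose
    cumulative probability exceeds top_p."""
    if temperature <= 0:
        return int(np.argmax(logits))  # fall back to greedy decoding
    # Temperature-scaled softmax with max-subtraction for numerical stability
    scaled = logits.astype(np.float64) / temperature
    probs = np.exp(scaled - np.max(scaled))
    probs /= probs.sum()
    # Sort descending and cut at the nucleus boundary (always keep >= 1 token)
    order = np.argsort(probs)[::-1]
    cutoff = int(np.searchsorted(np.cumsum(probs[order]), top_p)) + 1
    keep = order[:cutoff]
    # Renormalize over the kept tokens and sample
    return int(np.random.choice(keep, p=probs[keep] / probs[keep].sum()))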
def main():
    # Usage example
    print("Loading model...")
    t_start = time.time()
    model = RWKVModel(
        model_path='RWKV-x070-World-0.1B-v2.8-20241210-ctx4096.onnx',
        tokenizer_path='rwkv_vocab_v20230424.txt',
        use_external_embedding=True
    )
    print(f"Model loaded in {time.time() - t_start:.2f}s")

    prompt = "Here is an example of the Quick Sort algorithm implemented in C++:\n```cpp"
    print(f"\nPrompt: {prompt}")

    generated_text = model.generate(
        prompt=prompt,
        max_length=1024,
        temperature=0.7,
        stop_tokens={0, 1, 2, 3}  # special tokens used as stop markers
    )
    print("\nFull text:")
    print(generated_text)


if __name__ == '__main__':
    main()