from rwkv_src.rwkv_model import RWKV_RNN, make_chunks
import types
import os
import torch
import numpy as np
import argparse
from pathlib import Path
import onnx
from onnx import shape_inference

parser = argparse.ArgumentParser(description='Convert model')
parser.add_argument('model', type=Path, help='Path to RWKV pth file')
parser.add_argument('--chunks', type=int, default=1, help='Number of chunks')
parser.add_argument('--ext_embedding', action='store_true', default=False, help='Use external embedding')
parser.add_argument('--prefill_model', action='store_true', help='Convert model for sequential prefill')
parser.add_argument('--wkv_customop', action='store_true', help='Use custom op for wkv')
parser_args = parser.parse_args()

# Prefill models process 32 tokens per step; decode models process one token at a time.
seq_length = 32 if parser_args.prefill_model else 1

model_args = types.SimpleNamespace()
model_args.USE_CUDA = False
model_args.fp16 = False
model_args.wkv_customop = parser_args.wkv_customop
model_args.USE_EMBEDDING = not parser_args.ext_embedding
model_args.MODEL_NAME = str(parser_args.model)

# ABC/MIDI and v7 (x070) models need no rescaling; other models rescale
# activations every 6 layers.
if 'ABC' in model_args.MODEL_NAME or 'MIDI' in model_args.MODEL_NAME or 'x070' in model_args.MODEL_NAME:
    model_args.RESCALE_LAYER = 0
else:
    model_args.RESCALE_LAYER = 6

# With --chunks > 1 the model is split into a list of sub-models so each
# exported ONNX graph stays smaller.
model = make_chunks(parser_args.chunks, model_args) if parser_args.chunks > 1 else RWKV_RNN(model_args)

if parser_args.prefill_model:
    model_args.MODEL_NAME += "_prefill"

os.makedirs("onnx", exist_ok=True)

if isinstance(model, list):
    args = model[0].args
    model_name = args.MODEL_NAME.split("/")[-1]
    dtype = torch.float16 if args.fp16 else torch.float32
    if not args.USE_EMBEDDING:
        # Dump the embedding table to a separate file when it is kept outside the graph.
        model[0].emb_weight.cpu().numpy().astype(np.float32).tofile(f"onnx/{model_name}_chunk1of{len(model)}.emb")

    # Three state tensors per layer: two token-shift vectors and one wkv state matrix.
    states = []
    for i in range(args.n_layer):
        states.append(torch.zeros(1, args.n_embd, dtype=dtype))
        states.append(torch.zeros(args.n_head, args.head_size, args.head_size, dtype=dtype))
        states.append(torch.zeros(1, args.n_embd, dtype=dtype))
    if model[0].device != torch.device('cpu'):
        states = [tensor.to(model[0].device) for tensor in states]

    for i in range(len(model)):
        dirname = f"onnx/{model_name}_chunk{i+1}of{len(model)}"
        os.makedirs(dirname, exist_ok=True)
        # The first chunk takes token ids when the embedding is inside the graph;
        # every other chunk takes the previous chunk's hidden states.
        if i == 0 and args.USE_EMBEDDING:
            in0 = torch.LongTensor([[1]*seq_length])
        else:
            in0 = torch.zeros(1, seq_length, args.n_embd, dtype=dtype)
        if model[0].device != torch.device('cpu'):
            in0 = in0.to(model[0].device)
        # Only the state slots belonging to this chunk's layer range are wired in.
        inputs = {'in0': in0, 'state': [states[j] for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]}
        input_names = ['in'] + [f'state{j}_in' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]
        output_names = ['out'] + [f'state{j}_out' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]

        if args.wkv_customop:
            # Export wkv as an opaque custom op; the symbolic function only declares
            # output dtypes and shapes so ONNX shape inference can proceed.
            from torch.onnx.symbolic_helper import _get_tensor_sizes
            from torch.onnx import register_custom_op_symbolic
            op_name = "rwkv::wkv_chunk" if parser_args.prefill_model else "rwkv::wkv"
            def onnx_custom_wkv(g, k, v, r, state2, time_first, time_decay):
                out1, out2 = g.op(op_name, k, v, r, state2, time_first, time_decay, outputs=2)
                return out1.setType(k.type().with_dtype(torch.float32).with_sizes([seq_length, _get_tensor_sizes(k)[0], 1, args.head_size])),\
                    out2.setType(k.type().with_dtype(torch.float32).with_sizes([1, _get_tensor_sizes(k)[0], args.head_size, args.head_size]))
            register_custom_op_symbolic(op_name, onnx_custom_wkv, 9)
        onnx_output_path = f"{dirname}/{model_name}_chunk{i+1}of{len(model)}.onnx"
        torch.onnx.export(model[i], inputs, onnx_output_path, input_names=input_names, output_names=output_names, opset_version=17)
        shape_inference.infer_shapes_path(onnx_output_path)
        onnx_model = onnx.load(onnx_output_path)
        # Expose initializer shapes as value_info to make the model compatible with other frameworks.
        for initializer in onnx_model.graph.initializer:
            shape = list(initializer.dims)
            value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
            onnx_model.graph.value_info.append(value_info)
        onnx.save_model(onnx_model, onnx_output_path, save_as_external_data=True, all_tensors_to_one_file=True)
        print(f"onnx model chunk{i+1} saved to {onnx_output_path}")
else:
    args = model.args
    model_name = args.MODEL_NAME.split("/")[-1]
    dtype = torch.float16 if args.fp16 else torch.float32
    if not args.USE_EMBEDDING:
        # Dump the embedding table to a separate file when it is kept outside the graph.
        model.emb_weight.cpu().numpy().astype(np.float32).tofile(f"onnx/{model_name}.emb")

    in0 = torch.LongTensor([[1]*seq_length]) if args.USE_EMBEDDING else torch.zeros(1, seq_length, args.n_embd, dtype=dtype)
    # Three state tensors per layer: two token-shift vectors and one wkv state matrix.
    states = []
    for i in range(args.n_layer):
        states.append(torch.zeros(1, args.n_embd, dtype=dtype))
        states.append(torch.zeros(args.n_head, args.head_size, args.head_size, dtype=dtype))
        states.append(torch.zeros(1, args.n_embd, dtype=dtype))
    if model.device != torch.device('cpu'):
        states = [tensor.to(model.device) for tensor in states]

    inputs = {'in0': in0, 'state': states}
    input_names = ['in'] + [f'state{i}_in' for i in range(3*args.n_layer)]
    output_names = ['logits'] + [f'state{i}_out' for i in range(3*args.n_layer)]

    onnx_output_path = f"onnx/{model_name}.onnx"
    torch.onnx.export(model, inputs, onnx_output_path, input_names=input_names, output_names=output_names, opset_version=17)
    shape_inference.infer_shapes_path(onnx_output_path)
    onnx_model = onnx.load(onnx_output_path)
    # Expose initializer shapes as value_info to make the model compatible with other frameworks.
    for initializer in onnx_model.graph.initializer:
        shape = list(initializer.dims)
        value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
        onnx_model.graph.value_info.append(value_info)
    onnx.save_model(onnx_model, onnx_output_path, save_as_external_data=True, all_tensors_to_one_file=True)
    print(f"onnx model saved to {onnx_output_path}")
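
# Example invocations (a sketch; the script filename "convert_model.py" and the
# checkpoint path below are assumptions, not taken from this file):
#
#   python convert_model.py /path/to/RWKV-model.pth --chunks 2
#   python convert_model.py /path/to/RWKV-model.pth --prefill_model --wkv_customop
#
# The first command splits the decode-mode model into two ONNX chunks under onnx/;
# the second exports a 32-token prefill graph that emits the wkv custom op.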