# rwkv-7-world-ONNX-RKNN2 / export_onnx.py
from rwkv_src.rwkv_model import RWKV_RNN, make_chunks
import types
import os
import torch
import numpy as np
import argparse
import json
import copy
from pathlib import Path
import onnx
from onnx import shape_inference
parser = argparse.ArgumentParser(description='Convert model')
parser.add_argument('model', type=Path, help='Path to RWKV pth file')
parser.add_argument('--chunks', type=int, default=1, help='Number of chunks')
parser.add_argument('--ext_embedding', action='store_true', default=False, help='Use external embedding')
parser.add_argument('--prefill_model', action='store_true', help='Convert model for sequential prefill')
parser.add_argument('--wkv_customop', action='store_true', help='Use custom op for wkv')
parser_args = parser.parse_args()
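# Example invocation (model filename is illustrative):
#   python export_onnx.py path/to/RWKV-x070-World.pth --chunks 2 --wkv_customop
# Prefill graphs are traced with a fixed 32-token sequence; decode graphs process one token per step.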
seq_length = 32 if parser_args.prefill_model else 1
model_args = types.SimpleNamespace()
model_args.USE_CUDA = False
model_args.fp16 = False
model_args.wkv_customop = parser_args.wkv_customop
model_args.USE_EMBEDDING = False if parser_args.ext_embedding else True
model_args.MODEL_NAME = str(parser_args.model)
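# ABC, MIDI and x070 (RWKV-7) checkpoints are exported without layer rescaling; other checkpoints rescale every 6 layers.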
if 'ABC' in model_args.MODEL_NAME or 'MIDI' in model_args.MODEL_NAME or 'x070' in model_args.MODEL_NAME:
    model_args.RESCALE_LAYER = 0
else:
    model_args.RESCALE_LAYER = 6
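# Build the model; with --chunks > 1 it is split into sequential sub-models, each exported as its own ONNX graph.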
model = make_chunks(parser_args.chunks, model_args) if parser_args.chunks > 1 else RWKV_RNN(model_args)
if parser_args.prefill_model:
    model_args.MODEL_NAME = model_args.MODEL_NAME + "_prefill"
os.path.exists("onnx") or os.mkdir("onnx")
if isinstance(model, list):
    args = model[0].args
    if not args.USE_EMBEDDING:
        model[0].emb_weight.cpu().numpy().astype(np.float32).tofile("onnx/" + args.MODEL_NAME.split("/")[-1] + f"_chunk1of{len(model)}.emb")
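    # Zero-initialized recurrent state: three tensors per layer (attention token-shift, wkv state matrix, FFN token-shift).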
    fp16 = args.fp16
    states = []
    for i in range(args.n_layer):
        states.append(torch.zeros(1, args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
        states.append(torch.zeros(args.n_head, args.head_size, args.head_size, dtype=torch.float16 if fp16 else torch.float32))
        states.append(torch.zeros(1, args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
    if model[0].device != torch.device('cpu'):
        states = [i.to(model[0].device) for i in states]
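    # Export each chunk in turn: chunk 1 takes token ids when the embedding is baked in; later chunks take hidden states from the previous chunk.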
    for i in range(len(model)):
        dirname = "onnx/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}"
        os.makedirs(dirname, exist_ok=True)
        if i == 0 and args.USE_EMBEDDING:
            in0 = torch.LongTensor([[1]*seq_length])
        else:
            in0 = torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
        if model[0].device != torch.device('cpu'):
            in0 = in0.to(model[0].device)
        inputs = {'in0': in0, 'state': [states[j] for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]}
        input_names = ['in'] + [f'state{j}_in' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]
        output_names = ['out'] + [f'state{j}_out' for j in range(3*model[i].layer_begin, 3*model[i].layer_end)]
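        # With --wkv_customop, emit the wkv recurrence as a custom rwkv::wkv / rwkv::wkv_chunk op instead of tracing it; the symbolic only declares output dtypes and shapes.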
        if args.wkv_customop:
            from torch.onnx.symbolic_helper import _get_tensor_sizes
            from torch.onnx import register_custom_op_symbolic
            op_name = "rwkv::wkv_chunk" if parser_args.prefill_model else "rwkv::wkv"
            def onnx_custom_wkv(g, k, v, r, state2, time_first, time_decay):
                out1, out2 = g.op(op_name, k, v, r, state2, time_first, time_decay, outputs=2)
                return out1.setType(k.type().with_dtype(torch.float32).with_sizes([seq_length, _get_tensor_sizes(k)[0], 1, args.head_size])),\
                       out2.setType(k.type().with_dtype(torch.float32).with_sizes([1, _get_tensor_sizes(k)[0], args.head_size, args.head_size]))
            register_custom_op_symbolic(op_name, onnx_custom_wkv, 9)
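        # Export the chunk, run shape inference in place, then re-save with weights as external data (keeps large graphs under the 2 GB protobuf limit).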
        onnx_path = dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx"
        torch.onnx.export(model[i], inputs, onnx_path, input_names=input_names, output_names=output_names, opset_version=17)
        shape_inference.infer_shapes_path(onnx_path)
        onnx_model = onnx.load(onnx_path)
        # Record each initializer's shape as value_info to keep the model compatible with other frameworks
        for initializer in onnx_model.graph.initializer:
            shape = list(initializer.dims)
            value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
            onnx_model.graph.value_info.append(value_info)
        onnx.save_model(onnx_model, onnx_path, save_as_external_data=True, all_tensors_to_one_file=True)
        print(f"onnx model chunk{i+1} saved to {onnx_path}")
else:
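    # Single-graph export path: the whole model is traced into one ONNX file.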
    args = model.args
    if not args.USE_EMBEDDING:
        model.emb_weight.cpu().numpy().astype(np.float32).tofile("onnx/" + args.MODEL_NAME.split("/")[-1] + ".emb")
    fp16 = args.fp16
    in0 = torch.LongTensor([[1]*seq_length]) if args.USE_EMBEDDING else torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
    states = []
    for i in range(model.args.n_layer):
        states.append(torch.zeros(1, model.args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
        states.append(torch.zeros(model.args.n_head, model.args.head_size, model.args.head_size, dtype=torch.float16 if fp16 else torch.float32))
        states.append(torch.zeros(1, model.args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
    if model.device != torch.device('cpu'):
        states = [tensor.to(model.device) for tensor in states]
    inputs = {'in0': in0, 'state': states}
    input_names = ['in'] + [f'state{i}_in' for i in range(3*model.args.n_layer)]
    output_names = ['logits'] + [f'state{i}_out' for i in range(3*model.args.n_layer)]
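    # Trace, shape-infer, and save the single-graph model, mirroring the per-chunk path above.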
    onnx_path = "onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx"
    torch.onnx.export(model, inputs, onnx_path, input_names=input_names, output_names=output_names, opset_version=17)
    shape_inference.infer_shapes_path(onnx_path)
    onnx_model = onnx.load(onnx_path)
    # Record each initializer's shape as value_info to keep the model compatible with other frameworks
    for initializer in onnx_model.graph.initializer:
        shape = list(initializer.dims)
        value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
        onnx_model.graph.value_info.append(value_info)
    onnx.save_model(onnx_model, onnx_path, save_as_external_data=True, all_tensors_to_one_file=True)
    print(f"onnx model saved to {onnx_path}")