"""Convert an RWKV pth checkpoint to ONNX, optionally split into chunks."""
import argparse
import os
import types
from pathlib import Path

import numpy as np
import onnx
import torch
from onnx import shape_inference

from rwkv_src.rwkv_model import RWKV_RNN, make_chunks

parser = argparse.ArgumentParser(description='Convert an RWKV pth checkpoint to ONNX')
parser.add_argument('model', type=Path, help='Path to RWKV pth file')
parser.add_argument('--chunks', type=int, default=1, help='Number of chunks to split the model into')
parser.add_argument('--ext_embedding', action='store_true', help='Keep the embedding table outside the ONNX graph')
parser.add_argument('--prefill_model', action='store_true', help='Convert model for sequential prefill')
parser.add_argument('--wkv_customop', action='store_true', help='Use custom op for wkv')
parser_args = parser.parse_args()

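# Example invocation (script and checkpoint names are illustrative):
#   python convert_model.py models/rwkv.pth --chunks 2

# A prefill graph is traced over a 32-token window; a decode graph over one token.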
seq_length = 32 if parser_args.prefill_model else 1

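# Model configuration consumed by RWKV_RNN / make_chunks.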
model_args = types.SimpleNamespace()
model_args.USE_CUDA = False
model_args.fp16 = False
model_args.wkv_customop = parser_args.wkv_customop
model_args.USE_EMBEDDING = not parser_args.ext_embedding
model_args.MODEL_NAME = str(parser_args.model)

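# RESCALE_LAYER halves activations every N layers, RWKV's guard against fp16
# overflow; ABC/MIDI music models and v7 ("x070") checkpoints don't need it.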
if 'ABC' in model_args.MODEL_NAME or 'MIDI' in model_args.MODEL_NAME or 'x070' in model_args.MODEL_NAME:
    model_args.RESCALE_LAYER = 0
else:
    model_args.RESCALE_LAYER = 6

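# Build the model whole, or split it into sequential chunks (e.g. when a
# target device cannot hold the full network at once).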
model = make_chunks(parser_args.chunks, model_args) if parser_args.chunks > 1 else RWKV_RNN(model_args)

if parser_args.prefill_model:
    model_args.MODEL_NAME = model_args.MODEL_NAME + "_prefill"

os.makedirs("onnx", exist_ok=True)

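# A chunked model is exported as one ONNX graph per chunk; a single model is
# exported as one graph in the else branch.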
if isinstance(model, list):
    args = model[0].args
    if not args.USE_EMBEDDING:
        model[0].emb_weight.cpu().numpy().astype(np.float32).tofile("onnx/" + args.MODEL_NAME.split("/")[-1] + f"_chunk1of{len(model)}.emb")
    fp16 = args.fp16
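    # Zero-initialized recurrent state: per layer, an attention token-shift
    # vector, a per-head wkv state matrix, and an FFN token-shift vector.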
    states = []
    for i in range(args.n_layer):
        states.append(torch.zeros(1, args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
        states.append(torch.zeros(args.n_head, args.head_size, args.head_size, dtype=torch.float16 if fp16 else torch.float32))
        states.append(torch.zeros(1, args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
    if model[0].device != torch.device('cpu'):
        states = [tensor.to(model[0].device) for tensor in states]

    for i in range(len(model)):
        dirname = "onnx/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}"
        os.makedirs(dirname, exist_ok=True)
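        # The first chunk takes token ids when the embedding is baked into the
        # graph; otherwise every chunk takes float hidden states.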
        if i == 0 and args.USE_EMBEDDING:
            in0 = torch.LongTensor([[1] * seq_length])
        else:
            in0 = torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
        if model[0].device != torch.device('cpu'):
            in0 = in0.to(model[0].device)
        inputs = {'in0': in0, 'state': [states[j] for j in range(3 * model[i].layer_begin, 3 * model[i].layer_end)]}
        input_names = ['in'] + [f'state{j}_in' for j in range(3 * model[i].layer_begin, 3 * model[i].layer_end)]
        output_names = ['out'] + [f'state{j}_out' for j in range(3 * model[i].layer_begin, 3 * model[i].layer_end)]

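        # When the custom wkv op is requested, register an ONNX symbolic so the
        # exporter can emit it with annotated output dtypes and shapes.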
        if args.wkv_customop:
            from torch.onnx import register_custom_op_symbolic
            from torch.onnx.symbolic_helper import _get_tensor_sizes
            op_name = "rwkv::wkv_chunk" if parser_args.prefill_model else "rwkv::wkv"
            def onnx_custom_wkv(g, k, v, r, state2, time_first, time_decay):
                out1, out2 = g.op(op_name, k, v, r, state2, time_first, time_decay, outputs=2)
                return out1.setType(k.type().with_dtype(torch.float32).with_sizes([seq_length, _get_tensor_sizes(k)[0], 1, args.head_size])), \
                    out2.setType(k.type().with_dtype(torch.float32).with_sizes([1, _get_tensor_sizes(k)[0], args.head_size, args.head_size]))
            register_custom_op_symbolic(op_name, onnx_custom_wkv, 9)

        onnx_path = dirname + "/" + args.MODEL_NAME.split("/")[-1] + f"_chunk{i+1}of{len(model)}.onnx"
        torch.onnx.export(model[i], inputs, onnx_path, input_names=input_names, output_names=output_names, opset_version=17)
        shape_inference.infer_shapes_path(onnx_path)
        onnx_model = onnx.load(onnx_path)

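        # Record each initializer's shape as value_info so tools that read the
        # graph see static shapes for every weight tensor.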
        for initializer in onnx_model.graph.initializer:
            shape = list(initializer.dims)
            value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
            onnx_model.graph.value_info.append(value_info)
        onnx.save_model(onnx_model, onnx_path, save_as_external_data=True, all_tensors_to_one_file=True)
        print(f"onnx model chunk{i+1} saved to {onnx_path}")

else:
    args = model.args
    if not args.USE_EMBEDDING:
        model.emb_weight.cpu().numpy().astype(np.float32).tofile("onnx/" + args.MODEL_NAME.split("/")[-1] + ".emb")
    fp16 = args.fp16
    in0 = torch.LongTensor([[1] * seq_length]) if args.USE_EMBEDDING else torch.zeros(1, seq_length, args.n_embd, dtype=torch.float16 if fp16 else torch.float32)
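    # Zero-initialized recurrent state: three tensors per layer, as in the
    # chunked path above.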
    states = []
    for i in range(model.args.n_layer):
        states.append(torch.zeros(1, model.args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
        states.append(torch.zeros(model.args.n_head, model.args.head_size, model.args.head_size, dtype=torch.float16 if fp16 else torch.float32))
        states.append(torch.zeros(1, model.args.n_embd, dtype=torch.float16 if fp16 else torch.float32))
    if model.device != torch.device('cpu'):
        states = [tensor.to(model.device) for tensor in states]
    inputs = {'in0': in0, 'state': states}
    input_names = ['in'] + [f'state{i}_in' for i in range(3 * model.args.n_layer)]
    output_names = ['logits'] + [f'state{i}_out' for i in range(3 * model.args.n_layer)]
    onnx_path = "onnx/" + args.MODEL_NAME.split("/")[-1] + ".onnx"
    torch.onnx.export(model, inputs, onnx_path, input_names=input_names, output_names=output_names, opset_version=17)
    shape_inference.infer_shapes_path(onnx_path)
    onnx_model = onnx.load(onnx_path)

    for initializer in onnx_model.graph.initializer:
        shape = list(initializer.dims)
        value_info = onnx.helper.make_tensor_value_info(initializer.name, initializer.data_type, shape)
        onnx_model.graph.value_info.append(value_info)
    onnx.save_model(onnx_model, onnx_path, save_as_external_data=True, all_tensors_to_one_file=True)
    print("onnx model saved to " + onnx_path)