baichuan-llama-7b / convert_baichuan_to_llama.py
fireballoon's picture
Update convert_baichuan_to_llama.py
7502c3a
raw
history blame
467 Bytes
from collections import OrderedDict
import torch
def convert_state_dict(state_dict):
    """Return a LLaMA-style copy of a Baichuan state dict.

    Every entry whose key contains ``W_pack`` is split along dim 0 into
    three equal parts, stored under the same key with ``W_pack`` replaced
    by ``q_proj``, ``k_proj`` and ``v_proj`` (in that order). All other
    entries are carried over unchanged. Key order follows the input.

    Args:
        state_dict: Mapping of parameter name to tensor.

    Returns:
        A new ``OrderedDict`` with the packed entries split; the input
        mapping itself is not modified.

    Raises:
        ValueError: If the first dimension of a ``W_pack`` tensor is not
            divisible by three.
    """
    converted = OrderedDict()
    for key, tensor in state_dict.items():
        if 'W_pack' not in key:
            converted[key] = tensor
            continue

        # Derive the per-projection row count from the tensor itself
        # instead of assuming a fixed hidden size of 4096.
        rows = tensor.shape[0]
        if rows % 3:
            raise ValueError(
                f"cannot split {key!r}: first dimension {rows} "
                "is not divisible by 3"
            )
        size = rows // 3

        # Slicing returns views of the packed tensor; no weight data is
        # copied here.
        for index, name in enumerate(('q_proj', 'k_proj', 'v_proj')):
            start = index * size
            converted[key.replace('W_pack', name)] = tensor[start:start + size]
    return converted


if __name__ == "__main__":
    # NOTE: torch.load unpickles the file, so only run this on checkpoints
    # from a trusted source.
    # map_location="cpu" lets the conversion run on machines without the
    # device the checkpoint was originally saved from.
    baichuan = torch.load("pytorch_model.bin", map_location="cpu")
    llama = convert_state_dict(baichuan)
    # The converted weights overwrite the input checkpoint in place.
    torch.save(llama, "pytorch_model.bin")