"""Convert a Baichuan checkpoint into a LLaMA-style state dict.

Baichuan stores the query/key/value projections fused in one ``W_pack``
tensor; this script splits every ``W_pack`` entry into separate
``q_proj`` / ``k_proj`` / ``v_proj`` entries (in that order along dim 0) and
copies all other entries through unchanged.
"""
import os
from collections import OrderedDict

import torch


def convert_baichuan_to_llama(baichuan, hidden_size=None):
    """Return a new OrderedDict with every ``W_pack`` tensor split into q/k/v.

    Args:
        baichuan: Mapping of parameter name -> tensor. It is not modified.
        hidden_size: Number of rows per projection. ``None`` (the default)
            infers it as one third of each ``W_pack`` tensor's first
            dimension. When given explicitly, q is ``[:hidden_size]``, k is
            ``[hidden_size:2*hidden_size]`` and v is the remainder.

    Returns:
        OrderedDict preserving the input key order. Each ``W_pack`` key is
        replaced by its three ``q_proj``/``k_proj``/``v_proj`` keys in place.

    Raises:
        ValueError: if ``hidden_size`` is ``None`` and a ``W_pack`` tensor's
            first dimension is not divisible by 3.
    """
    llama = OrderedDict()
    for key, tensor in baichuan.items():
        if 'W_pack' not in key:
            llama[key] = tensor
            continue

        size = hidden_size
        if size is None:
            # Infer the split instead of hard-coding 4096 so models with a
            # different hidden size are not silently mis-sliced.
            rows = tensor.shape[0]
            if rows % 3:
                raise ValueError(
                    f"{key}: first dimension {rows} is not divisible by 3; "
                    "pass hidden_size explicitly"
                )
            size = rows // 3

        # NOTE: these slices are views sharing W_pack's storage; torch.save
        # preserves that sharing. Together they cover the whole storage.
        llama[key.replace('W_pack', 'q_proj')] = tensor[:size]
        llama[key.replace('W_pack', 'k_proj')] = tensor[size:size * 2]
        llama[key.replace('W_pack', 'v_proj')] = tensor[size * 2:]
    return llama


def main(src="pytorch_model.bin", dst="pytorch_model.bin"):
    """Load ``src``, convert it, and write the result to ``dst``.

    The defaults overwrite the checkpoint in place, as the original script
    did. The result is written to a temporary file first and then atomically
    moved into place, so an interrupted save does not destroy the input.
    """
    # SECURITY: torch.load unpickles the file; only run this on trusted
    # checkpoints.
    baichuan = torch.load(src)
    llama = convert_baichuan_to_llama(baichuan)
    tmp = dst + ".tmp"
    torch.save(llama, tmp)
    os.replace(tmp, dst)


if __name__ == "__main__":
    main()