File size: 467 Bytes
3cdc5c4
 
 
 
 
 
 
 
 
 
 
 
 
7502c3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from collections import OrderedDict

import torch


def convert_baichuan_to_llama(baichuan_state):
    """Convert a Baichuan state dict to the LLaMA parameter layout.

    Baichuan fuses the attention query/key/value projections into one
    ``W_pack`` matrix of shape ``(3 * hidden_size, hidden_size)``, whereas
    LLaMA stores them as separate ``q_proj``/``k_proj``/``v_proj`` weights.
    Every ``W_pack`` tensor is split into three equal row-chunks and the key
    is renamed accordingly; all other entries are copied through unchanged.

    Args:
        baichuan_state: mapping of parameter name -> tensor, as loaded from
            a Baichuan ``pytorch_model.bin``.

    Returns:
        ``OrderedDict`` with LLaMA-style keys, preserving the source
        iteration order (each ``W_pack`` entry expands to q/k/v in place).
    """
    llama = OrderedDict()
    for key, tensor in baichuan_state.items():
        if "W_pack" in key:
            # Split into three equal row-chunks derived from the tensor's
            # own shape instead of a hard-coded 4096, so checkpoints with a
            # different hidden size (e.g. 13B, hidden=5120) convert correctly.
            q, k, v = torch.chunk(tensor, 3, dim=0)
            llama[key.replace("W_pack", "q_proj")] = q
            llama[key.replace("W_pack", "k_proj")] = k
            llama[key.replace("W_pack", "v_proj")] = v
        else:
            llama[key] = tensor
    return llama


if __name__ == "__main__":
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only box.
    baichuan = torch.load("pytorch_model.bin", map_location="cpu")
    # NOTE: this overwrites the input checkpoint in place (same filename),
    # matching the original script's behavior.
    torch.save(convert_baichuan_to_llama(baichuan), "pytorch_model.bin")