File size: 395 Bytes
45b4aa7
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import argparse

def prune_ckpt(ckpt_path, save_path, key="state_dict"):
    """Extract one entry from a training checkpoint and save it on its own.

    The checkpoint at ``ckpt_path`` is loaded onto the CPU. The entry stored
    under ``key`` is pulled out, and only that entry is written to
    ``save_path``. Every other entry of the checkpoint is dropped.

    Args:
        ckpt_path: Path of the checkpoint file to read.
        save_path: Destination path for the extracted entry.
        key: Name of the checkpoint entry to keep. Defaults to
            ``"state_dict"``, which matches the original behaviour.

    Returns:
        The extracted entry. The same object is also written to
        ``save_path``.

    Raises:
        KeyError: If the loaded checkpoint has no ``key`` entry. The message
            lists the entries that are present.
    """
    # SECURITY: torch.load unpickles the file, so only prune checkpoints from
    # trusted sources.
    # NOTE(review): torch>=2.6 defaults to weights_only=True. Checkpoints that
    # pickle arbitrary Python objects may fail to load there; confirm the
    # torch version in use.
    # Mapping to CPU lets GPU-saved checkpoints load on CPU-only machines.
    raw = torch.load(ckpt_path, map_location=torch.device('cpu'))
    try:
        state_dict = raw[key]
    except KeyError:
        if isinstance(raw, dict):
            found = sorted(map(str, raw))
        else:
            found = type(raw).__name__
        raise KeyError(
            f"checkpoint {ckpt_path!r} has no {key!r} entry; found: {found}"
        ) from None
    torch.save(state_dict, save_path)
    return state_dict

if __name__ == '__main__':
    # Command-line entry point. It parses the source checkpoint path
    # (--ckpt_path) and the destination path for the pruned file
    # (--save_path).
    # NOTE(review): both options are declared without required=True or a
    # default, so either one is None when omitted on the command line.
    # `args` is first bound to the ArgumentParser and then rebound to the
    # parsed Namespace.
    args = argparse.ArgumentParser()
    args.add_argument('--ckpt_path', type=str)
    args.add_argument('--save_path', type=str)
    args = args.parse_args()