import torch | |
import argparse | |
def prune_ckpt(ckpt_path, save_path, key="state_dict"):
    """Extract one entry from a checkpoint file and save it on its own.

    The checkpoint at *ckpt_path* is loaded onto the CPU. The object stored
    under *key* (``"state_dict"`` by default) is pulled out, and only that
    object is written to *save_path*. Every other top-level entry of the
    checkpoint is left out of the output file.

    Args:
        ckpt_path: Path of the checkpoint file to read.
        save_path: Path the extracted entry is written to via ``torch.save``.
        key: Top-level checkpoint entry to keep. Defaults to ``"state_dict"``,
            which preserves the original behavior.

    Raises:
        KeyError: If the loaded checkpoint has no *key* entry. The message
            names the file and lists the top-level keys it does contain.

    Note:
        ``torch.load`` may deserialize with pickle (depending on the installed
        torch version's ``weights_only`` default), so only run this on
        checkpoint files from a trusted source.
    """
    # Map every tensor to the CPU so the file can be processed without the
    # device(s) the tensors were originally saved from.
    raw = torch.load(ckpt_path, map_location=torch.device('cpu'))
    try:
        state_dict = raw[key]
    except KeyError:
        # A bare KeyError('state_dict') says nothing about which file was
        # wrong or what it holds; report both.
        raise KeyError(
            f"checkpoint {ckpt_path!r} has no {key!r} entry; "
            f"top-level keys: {list(raw)}"
        ) from None
    torch.save(state_dict, save_path)
if __name__ == '__main__':
    # Command-line entry point: collect the input checkpoint path and the
    # output path for the pruned weights.
    cli = argparse.ArgumentParser()
    for flag in ('--ckpt_path', '--save_path'):
        cli.add_argument(flag, type=str)
    # NOTE(review): nothing in this view passes the parsed paths to
    # prune_ckpt; confirm a `prune_ckpt(args.ckpt_path, args.save_path)`
    # call follows, otherwise running the script does nothing.
    args = cli.parse_args()