Spaces:
Configuration error
Configuration error
from argparse import ArgumentParser
from pathlib import Path

import torch
def simplify_pth(pth_name, project_name, map_location='cpu'):
    """Write a slimmed-down copy of a training checkpoint.

    Loads ``./checkpoints/<project_name>/<pth_name>`` and saves
    ``./clean_<pth_name>`` in the current working directory. Only the
    ``epoch`` and ``state_dict`` entries are carried over. The
    ``global_step``, ``checkpoint_callback_best``, ``optimizer_states`` and
    ``lr_schedulers`` keys are still written, but set to ``None``, so that
    the resulting dict keeps the same top-level keys.

    Args:
        pth_name: File name of the checkpoint inside the project directory.
        project_name: Sub-directory of ``./checkpoints`` holding the checkpoint.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            a checkpoint saved from GPU can be cleaned on a CPU-only machine;
            tensors in the cleaned file are then stored on CPU.

    Raises:
        KeyError: if the loaded checkpoint has no ``epoch`` or
            ``state_dict`` entry.
    """
    source_path = Path('checkpoints') / project_name / pth_name
    # NOTE(review): torch.load unpickles the file; only run this on
    # checkpoints you trust.
    checkpoint_dict = torch.load(source_path, map_location=map_location)

    slim_checkpoint = {
        'epoch': checkpoint_dict['epoch'],
        'state_dict': checkpoint_dict['state_dict'],
        # Training-resume state is dropped to shrink the file; the keys are
        # kept (as None) so the saved dict has the same top-level layout.
        'global_step': None,
        'checkpoint_callback_best': None,
        'optimizer_states': None,
        'lr_schedulers': None,
    }
    torch.save(slim_checkpoint, Path(f'clean_{pth_name}'))
def main(argv=None):
    """Command-line entry point: clean one checkpoint of a project.

    Builds the file name ``model_ckpt_steps_<steps>.ckpt`` and passes it,
    together with the project name, to ``simplify_pth``.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original command-line use.
    """
    parser = ArgumentParser(
        description='Strip training state from a checkpoint, keeping only '
                    'epoch and state_dict.')
    # Both options are required: without them the script would look for
    # ./checkpoints/None/model_ckpt_steps_None.ckpt and fail confusingly.
    parser.add_argument('--proj', type=str, required=True,
                        help='project directory name under ./checkpoints')
    parser.add_argument('--steps', type=str, required=True,
                        help='step number used in the checkpoint file name')
    args = parser.parse_args(argv)
    model_name = f"model_ckpt_steps_{args.steps}.ckpt"
    simplify_pth(model_name, args.proj)
# Run the CLI only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()