import pathlib
import os

import torch
def save(ckpt_dir, module, optimizer, scheduler, global_step, keep_latest=2, model_name='model'):
    # Write a checkpoint named '<model_name>-<step>.pth' into ckpt_dir, keeping
    # at most `keep_latest` checkpoints (older ones are deleted first).
    pathlib.Path(ckpt_dir).mkdir(exist_ok=True, parents=True)

    # Sort existing checkpoints newest-first by modification time, and delete
    # all but the keep_latest-1 we want to retain alongside the new file.
    prev_ckpts = list(pathlib.Path(ckpt_dir).glob('%s-*pth' % model_name))
    prev_ckpts.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    if len(prev_ckpts) > keep_latest - 1:
        for f in prev_ckpts[keep_latest - 1:]:
            f.unlink()

    save_path = '%s/%s-%09d.pth' % (ckpt_dir, model_name, global_step)
    save_dict = {
        "model": module.state_dict(),
        "optimizer": optimizer.state_dict(),
        "global_step": global_step,
    }
    if scheduler is not None:
        save_dict['scheduler'] = scheduler.state_dict()
    print(f"saving {save_path}")
    torch.save(save_dict, save_path)
    return save_path
def load(fabric, ckpt_path, model, optimizer=None, scheduler=None, model_ema=None,
         step=0, model_name='model', ignore_load=None, strict=True, verbose=True,
         weights_only=False):
    # Load a checkpoint from ckpt_path, which may be either a direct path to a
    # .pth file or a directory of '<model_name>-<step>.pth' files (in which
    # case the most recently modified one is used). Returns the step to resume
    # from, or 0 if no checkpoint was found in the directory.
    # Note: model_ema is accepted for API compatibility but is unused here.
    if verbose:
        print('reading ckpt from %s' % ckpt_path)
    if not os.path.exists(ckpt_path):
        print('...there is no full checkpoint in %s' % ckpt_path)
        print('-- note this function no longer appends "saved_checkpoints/" before the ckpt_path --')
        assert False
    else:
        if os.path.isfile(ckpt_path):
            path = ckpt_path
            print('...found checkpoint %s' % (path))
        else:
            prev_ckpts = list(pathlib.Path(ckpt_path).glob('%s-*pth' % model_name))
            prev_ckpts.sort(key=lambda p: p.stat().st_mtime, reverse=True)
            if len(prev_ckpts):
                path = prev_ckpts[0]
                # e.g., './checkpoints/2Ai4_5e-4_base18_1539/model-000050000.pth'
                # OR ./whatever.pth
                # Parse the step count out of the filename.
                step = int(str(path).split('-')[-1].split('.')[0])
                if verbose:
                    print('...found checkpoint %s; (parsed step %d from path)' % (path, step))
            else:
                print('...there is no full checkpoint here!')
                return 0
        if fabric is not None:
            checkpoint = fabric.load(path)
        else:
            checkpoint = torch.load(path, weights_only=weights_only)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        # Older checkpoints may lack scheduler state, since save() only writes
        # it when a scheduler was provided.
        if scheduler is not None and 'scheduler' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler'])
        assert ignore_load is None  # not ready yet
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            # Plain state_dict checkpoint, with no wrapping dict.
            state_dict = checkpoint
        model.load_state_dict(state_dict, strict=strict)
    return step
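

# A minimal smoke-test sketch (this block is not part of the original module):
# it exercises save() and load() together with a tiny model, no Lightning
# Fabric (fabric=None falls back to plain torch.load), and a temporary
# directory standing in for a real checkpoint folder.
if __name__ == '__main__':
    import tempfile
    import torch.nn as nn

    net = nn.Linear(4, 2)
    opt = torch.optim.SGD(net.parameters(), lr=1e-3)

    with tempfile.TemporaryDirectory() as ckpt_dir:
        # Two saves; with the default keep_latest=2, both files remain on disk.
        save(ckpt_dir, net, opt, None, global_step=100)
        save(ckpt_dir, net, opt, None, global_step=200)

        # Pointing load() at the directory picks the newest checkpoint and
        # parses the step back out of its filename.
        step = load(None, ckpt_dir, net, optimizer=opt)
        print('resumed at step', step)  # expect: resumed at step 200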