Spaces:
Build error
Build error
| import glob | |
| import logging | |
| import os | |
| import re | |
| import torch | |
def get_last_checkpoint(work_dir, steps=None):
    """Load the highest-step checkpoint found in ``work_dir``.

    Args:
        work_dir: directory searched by ``get_all_ckpts``.
        steps: optional step number; when given only that checkpoint is
            looked up.

    Returns:
        tuple: ``(checkpoint, path)``, where ``checkpoint`` is the object
        returned by ``torch.load`` (mapped to CPU). Both elements are
        ``None`` when no matching checkpoint file exists.
    """
    candidates = get_all_ckpts(work_dir, steps)
    if not candidates:
        return None, None
    # get_all_ckpts orders paths by step number, highest first.
    newest = candidates[0]
    loaded = torch.load(newest, map_location='cpu')
    logging.info(f'load module from checkpoint: {newest}')
    return loaded, newest
def get_all_ckpts(work_dir, steps=None):
    """Return checkpoint paths in ``work_dir``, highest step number first.

    Checkpoints are expected to be named ``model_ckpt_steps_<N>.ckpt``.

    Args:
        work_dir: directory to search. Glob metacharacters in it (``*``,
            ``?``, ``[``) are escaped so the directory is matched literally.
        steps: if not None, only ``model_ckpt_steps_{steps}.ckpt`` is
            looked up; otherwise every ``model_ckpt_steps_*.ckpt`` file.

    Returns:
        list[str]: matching paths sorted by their integer step, descending.
        Files the glob matches but whose name carries no integer step
        (e.g. ``model_ckpt_steps_best.ckpt``) are skipped rather than
        raising ``IndexError``.
    """
    # Raw string: '\d' / '\.' are regex escapes, not Python string escapes.
    step_re = re.compile(r'.*steps_(\d+)\.ckpt')
    step_part = '*' if steps is None else steps
    pattern = f'{glob.escape(str(work_dir))}/model_ckpt_steps_{step_part}.ckpt'

    numbered = []
    for path in glob.glob(pattern):
        # Match on the file name only, so a 'steps_N.ckpt'-like parent
        # directory cannot supply the step number.
        match = step_re.search(os.path.basename(path))
        if match is not None:
            numbered.append((int(match.group(1)), path))
    return [path for _, path in sorted(numbered, key=lambda item: -item[0])]
def _select_sub_state_dict(state_dict, model_name):
    """Return the entries of ``state_dict`` belonging to ``model_name``, prefix stripped."""
    if any('.' in k for k in state_dict):
        # Flat layout: at least one key is dotted, so keys are treated as
        # '<model_name>.<param>' paths and filtered by that prefix.
        prefix = f'{model_name}.'
        return {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    # Nested layout: top-level keys are sub-model names.
    if '.' not in model_name:
        return state_dict[model_name]
    # 'base.rest': pick state_dict['base'] and filter it by the 'rest.' prefix.
    base_model_name, rest_model_name = model_name.split('.', 1)
    prefix = f'{rest_model_name}.'
    return {k[len(prefix):]: v for k, v in state_dict[base_model_name].items()
            if k.startswith(prefix)}


def _drop_shape_mismatches(cur_model, state_dict):
    """Return a copy of ``state_dict`` without entries whose shape differs from ``cur_model``'s."""
    model_state = cur_model.state_dict()
    kept = {}
    for key, param in state_dict.items():
        if key in model_state and model_state[key].shape != param.shape:
            print("| Unmatched keys: ", key, model_state[key].shape, param.shape)
            continue
        kept[key] = param
    return kept


def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True):
    """Load the weights stored for ``model_name`` into ``cur_model``.

    Args:
        cur_model: object whose ``state_dict()`` / ``load_state_dict()`` are
            used (presumably a ``torch.nn.Module``).
        ckpt_base_dir: if it names an existing file, that file is loaded;
            otherwise it is treated as a directory and the highest-step
            checkpoint is taken via ``get_last_checkpoint``.
        model_name: which part of ``checkpoint["state_dict"]`` to load. It
            may be dotted (``'a.b'``) to address a nested sub-model.
        force: if True, a missing checkpoint raises; otherwise the message
            is printed and the function returns without loading.
        strict: forwarded to ``load_state_dict``. When False, entries whose
            shape differs from the model's are also dropped first.

    Raises:
        AssertionError: no checkpoint was found and ``force`` is True.
    """
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        ckpt_path = ckpt_base_dir
        checkpoint = torch.load(ckpt_base_dir, map_location='cpu')
    else:
        base_dir = ckpt_base_dir
        checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir)

    if checkpoint is None:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Explicit raise (not `assert False`) so the check survives `python -O`.
            raise AssertionError(e_msg)
        print(e_msg)
        return

    state_dict = _select_sub_state_dict(checkpoint["state_dict"], model_name)
    if not strict:
        state_dict = _drop_shape_mismatches(cur_model, state_dict)
    cur_model.load_state_dict(state_dict, strict=strict)
    print(f"| load '{model_name}' from '{ckpt_path}'.")