Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import os | |
| import random | |
| import time | |
| import torch | |
| from os import path as osp | |
| from .dist_util import master_only | |
def set_random_seed(seed):
    """Seed every random number generator this module knows about.

    The same seed is handed, in order, to Python's ``random`` module, NumPy's
    global generator, ``torch.manual_seed``, ``torch.cuda.manual_seed`` and
    ``torch.cuda.manual_seed_all``.

    Args:
        seed (int): Seed value passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``.

    Returns:
        str: Timestamp string, e.g. ``'20240131_235959'``.
    """
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
def mkdir_and_rename(path):
    """Create a fresh directory at ``path``, archiving any existing one.

    If ``path`` already exists it is renamed to
    ``<path>_archived_<timestamp>`` (timestamp from ``get_time_str``) and a
    message is printed. Afterwards ``path`` is created again.

    Args:
        path (str): Folder path. Trailing path separators are tolerated.
    """
    if osp.exists(path):
        # Strip trailing separators so the archive is a sibling of ``path``.
        # Without this, 'exp/' + '_archived_...' names a location *inside*
        # the directory being renamed and os.rename fails.
        separators = os.sep + (os.altsep or '')
        archive_base = path.rstrip(separators) or path
        new_name = archive_base + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)
def make_exp_dirs(opt):
    """Create the directory layout described by ``opt['path']``.

    The root folder (``experiments_root`` when ``opt['is_train']`` is truthy,
    otherwise ``results_root``) is created through ``mkdir_and_rename``.
    Every other entry of ``opt['path']`` is created with ``os.makedirs``,
    except keys containing ``strict_load``, ``pretrain_network``, ``resume``
    or ``param_key``, which are skipped.

    Args:
        opt (dict): Options; reads ``opt['path']`` and ``opt['is_train']``.
            ``opt['path']`` itself is not modified (a copy is used).
    """
    remaining = opt['path'].copy()
    root_key = 'experiments_root' if opt['is_train'] else 'results_root'
    mkdir_and_rename(remaining.pop(root_key))

    # Keys containing any of these markers are not turned into folders.
    skipped_markers = ('strict_load', 'pretrain_network', 'resume', 'param_key')
    for name, folder in remaining.items():
        if not any(marker in name for marker in skipped_markers):
            os.makedirs(folder, exist_ok=True)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Files whose name starts with '.' are never yielded. With
    ``recursive=True`` every sub-directory (hidden ones included) is
    descended into; entries that are neither a visible regular file nor a
    directory (hidden files, broken symlinks, ...) are skipped.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files, as paths relative to
        ``dir_path`` (or prefixed with it when ``full_path`` is True).

    Raises:
        TypeError: If ``suffix`` is neither None, a str, nor a tuple. Raised
            at call time, before any iteration happens.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(current_dir):
        # The context manager releases the directory handle even when the
        # consumer abandons the generator early.
        with os.scandir(current_dir) as entries:
            for entry in entries:
                if not entry.name.startswith('.') and entry.is_file():
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only real directories are descended into; recursing
                    # into a hidden *file* would make os.scandir raise.
                    yield from _scandir(entry.path)

    return _scandir(dir_path)
def check_resume(opt, resume_iter):
    """Point pretrain paths at the checkpoints of a resumed run.

    Does nothing unless ``opt['path']['resume_state']`` is truthy. When
    resuming, for every top-level ``network_*`` key of ``opt``:

    * a notice is printed once if any ``pretrain_network_*`` path is set;
    * ``opt['path']['pretrain_<network>']`` is overwritten with
      ``<models>/net_<name>_<resume_iter>.pth``, unless the network is
      listed in ``opt['path']['ignore_resume_networks']``.

    Finally every ``param_key*`` entry equal to ``'params_ema'`` is switched
    to ``'params'``. ``opt['path']`` is modified in place.

    Args:
        opt (dict): Options.
        resume_iter (int): Resume iteration.
    """
    paths = opt['path']
    if not paths['resume_state']:
        return

    # Collect every network declared at the top level of the options.
    network_keys = [key for key in opt.keys() if key.startswith('network_')]

    if any(paths.get(f'pretrain_{net}') is not None for net in network_keys):
        print('pretrain_network path will be ignored during resuming.')

    # Redirect the pretrain paths to the checkpoints saved at resume_iter.
    for net in network_keys:
        skip_list = paths.get('ignore_resume_networks')
        if skip_list is not None and net in skip_list:
            continue
        pretrain_key = f'pretrain_{net}'
        short_name = net.replace('network_', '')
        paths[pretrain_key] = osp.join(paths['models'], f'net_{short_name}_{resume_iter}.pth')
        print(f"Set {pretrain_key} to {paths[pretrain_key]}")

    # When resuming, 'params_ema' keys are replaced by 'params'.
    for key in [k for k in paths.keys() if k.startswith('param_key')]:
        if paths[key] == 'params_ema':
            paths[key] = 'params'
            print(f'Set {key} to params')
def sizeof_fmt(size, suffix='B'):
    """Format a size with a binary (1024-based) unit prefix.

    The value is divided by 1024 until its magnitude drops below 1024 or the
    prefixes up to 'Z' are exhausted, in which case 'Y' is used.

    Args:
        size (int): File size.
        suffix (str): Suffix. Default: 'B'.

    Return:
        str: Formatted file size, e.g. ``'1.5 KB'``.
    """
    prefixes = ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z')
    step = 0
    # ``not abs(size) < 1024.0`` (rather than ``>=``) keeps NaN scaling all
    # the way up, exactly like a "return when below the limit" loop would.
    while step < len(prefixes) and not abs(size) < 1024.0:
        size /= 1024.0
        step += 1
    prefix = prefixes[step] if step < len(prefixes) else 'Y'
    return f'{size:3.1f} {prefix}{suffix}'