Spaces:
Running
Running
from config import * | |
import json | |
import os | |
import pprint as pp | |
import random | |
from datetime import date | |
from pathlib import Path | |
import numpy as np | |
import torch | |
import torch.backends.cudnn as cudnn | |
from torch import optim as optim | |
def setup_train(args):
    """Prepare a training run and return the folder its artefacts go into.

    Steps, in order: configure visible GPUs (``set_up_gpu`` also sets
    ``args.num_gpu``), create a fresh experiment folder, write ``config.json``
    into it, and pretty-print every argument whose value is not ``None``.

    Args:
        args: argument namespace (anything ``vars()`` accepts).

    Returns:
        Path of the newly created experiment folder.
    """
    set_up_gpu(args)
    export_root = create_experiment_export_folder(args)
    export_experiments_config_as_json(args, export_root)
    # Only echo the options that were actually given a value.
    populated_args = {}
    for key, value in vars(args).items():
        if value is not None:
            populated_args[key] = value
    pp.pprint(populated_args, width=1)
    return export_root
def create_experiment_export_folder(args):
    """Create a new, uniquely numbered folder for this experiment run.

    Args:
        args: namespace providing ``experiment_dir`` (parent directory, created
            together with any missing parents) and ``experiment_description``
            (prefix of the new folder's name).

    Returns:
        Path (str) of the created experiment folder.
    """
    experiment_dir, experiment_description = args.experiment_dir, args.experiment_description
    # makedirs creates nested parents too, and exist_ok avoids the
    # exists()/mkdir() race when several runs start at the same time.
    os.makedirs(experiment_dir, exist_ok=True)
    experiment_path = get_name_of_experiment_path(experiment_dir, experiment_description)
    os.mkdir(experiment_path)
    print('Folder created: ' + os.path.abspath(experiment_path))
    return experiment_path
def get_name_of_experiment_path(experiment_dir, experiment_description):
    """Build ``<experiment_dir>/<description>_<YYYY-MM-DD>_<idx>``.

    ``idx`` is the first index for which no such path exists yet, as reported
    by ``_get_experiment_index``.
    """
    dated_name = "_".join((experiment_description, str(date.today())))
    base_path = os.path.join(experiment_dir, dated_name)
    return "{}_{}".format(base_path, _get_experiment_index(base_path))
def _get_experiment_index(experiment_path):
    """Return the smallest idx >= 0 such that ``experiment_path + "_" + str(idx)`` does not exist."""
    candidate = 0
    while True:
        if not os.path.exists(experiment_path + "_" + str(candidate)):
            return candidate
        candidate += 1
def load_weights(model, path):
    """Placeholder that currently does nothing.

    Neither ``model`` nor ``path`` is used, and the function returns ``None``.
    NOTE(review): ``load_pretrained_weights`` appears to be the working loader
    in this module; confirm which one callers intend to use.
    """
    return None
def save_test_result(export_root, result):
    """Write ``result`` as 2-space-indented JSON to ``<export_root>/test_result.txt``.

    The file is overwritten if it already exists.
    """
    target = Path(export_root) / 'test_result.txt'
    with open(target, 'w') as handle:
        json.dump(result, handle, indent=2)
def export_experiments_config_as_json(args, experiment_path):
    """Dump every attribute of ``args`` to ``<experiment_path>/config.json`` (2-space indent)."""
    config_file = os.path.join(experiment_path, 'config.json')
    with open(config_file, 'w') as outfile:
        # vars() is evaluated after the file is opened, matching the original order.
        config = vars(args)
        json.dump(config, outfile, indent=2)
def fix_random_seed_as(random_seed):
    """Seed Python, PyTorch (CPU and all CUDA devices) and NumPy RNGs.

    Also sets cuDNN to deterministic mode with benchmarking disabled.
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for seed_fn in seeders:
        seed_fn(random_seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn.deterministic = True
    cudnn.benchmark = False
def set_up_gpu(args):
    """Restrict the visible CUDA devices to ``args.device_idx`` and record their count.

    Sets ``CUDA_VISIBLE_DEVICES`` in this process's environment and stores the
    number of listed devices on ``args.num_gpu``.

    Args:
        args: namespace whose ``device_idx`` is a comma-separated id string such
            as ``"0,1"``. A single integer id is also accepted. Blank entries
            are ignored, so ``"0,1,"`` means two devices and ``""`` means none.

    Raises:
        TypeError: if ``args.device_idx`` is ``None``.
    """
    if args.device_idx is None:
        raise TypeError('args.device_idx must be a device id or a comma-separated string of ids, not None')
    # Normalise: accept ints, strip spaces, and drop empty entries so the count
    # stored below always agrees with what is written to the environment.
    device_ids = [idx.strip() for idx in str(args.device_idx).split(",") if idx.strip()]
    # NOTE(review): presumably this must run before CUDA is initialised to take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(device_ids)
    args.num_gpu = len(device_ids)
def load_pretrained_weights(model, path):
    """Load weights from the checkpoint file at ``path`` into ``model`` in place.

    The checkpoint may be a dict holding the weights under ``STATE_DICT_KEY``
    (checked first) or under ``'state_dict'``, or it may be a bare state dict.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        path: checkpoint file path (relative paths are resolved to absolute).
    """
    # Deserialize onto the CPU so checkpoints saved on a GPU still load on
    # CPU-only hosts or with different device ids; load_state_dict then copies
    # the values into the model's existing tensors on their own device.
    # NOTE(review): torch.load unpickles data, so only load trusted checkpoints.
    chk_dict = torch.load(os.path.abspath(path), map_location='cpu')
    if STATE_DICT_KEY in chk_dict:
        model_state_dict = chk_dict[STATE_DICT_KEY]
    elif 'state_dict' in chk_dict:
        model_state_dict = chk_dict['state_dict']
    else:
        # No wrapper key: treat the file itself as the state dict.
        model_state_dict = chk_dict
    model.load_state_dict(model_state_dict)
def setup_to_resume(args, model, optimizer):
    """Restore model and optimizer state from the latest checkpoint of a previous run.

    Reads ``<args.resume_training>/models/checkpoint-recent.pth`` and loads the
    entries stored under ``STATE_DICT_KEY`` and ``OPTIMIZER_STATE_DICT_KEY``.

    Args:
        args: namespace whose ``resume_training`` is the previous run's folder.
        model: module to restore in place.
        optimizer: optimizer to restore in place.
    """
    checkpoint_path = os.path.join(os.path.abspath(args.resume_training), 'models/checkpoint-recent.pth')
    # Load to CPU first: both Module.load_state_dict and Optimizer.load_state_dict
    # place the values on the device of the target parameters, so resuming works
    # on a different or CPU-only machine.
    # NOTE(review): torch.load unpickles data, so only load trusted checkpoints.
    chk_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(chk_dict[STATE_DICT_KEY])
    optimizer.load_state_dict(chk_dict[OPTIMIZER_STATE_DICT_KEY])
def create_optimizer(model, args):
    """Build the optimizer named by ``args.optimizer`` over ``model``'s parameters.

    ``'Adam'`` gives ``optim.Adam(lr, weight_decay)``. Any other value gives
    ``optim.SGD(lr, weight_decay, momentum)``.
    """
    params = model.parameters()
    if args.optimizer != 'Adam':
        return optim.SGD(params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
class AverageMeterSet(object):
    """A dictionary of named ``AverageMeter`` objects.

    A meter is created the first time ``update`` is called with its name.
    Indexing with an unknown name returns a temporary meter that has recorded
    a single ``0``; that meter is not stored.
    """

    def __init__(self, meters=None):
        # Any falsy argument (None or an empty mapping) is replaced by a new dict.
        self.meters = meters or {}

    def __getitem__(self, key):
        if key in self.meters:
            return self.meters[key]
        # Unknown key: hand back a zero-valued placeholder without registering it.
        placeholder = AverageMeter()
        placeholder.update(0)
        return placeholder

    def update(self, name, value, n=1):
        """Feed ``value`` with weight ``n`` to meter ``name``, creating the meter if needed."""
        self.meters.setdefault(name, AverageMeter()).update(value, n)

    def reset(self):
        """Reset every stored meter."""
        for name in self.meters:
            self.meters[name].reset()

    def _collect(self, attribute, format_string):
        """Map ``format_string.format(name)`` to ``getattr(meter, attribute)`` for every meter."""
        return {
            format_string.format(name): getattr(meter, attribute)
            for name, meter in self.meters.items()
        }

    def values(self, format_string='{}'):
        """Latest value of each meter."""
        return self._collect('val', format_string)

    def averages(self, format_string='{}'):
        """Running average of each meter."""
        return self._collect('avg', format_string)

    def sums(self, format_string='{}'):
        """Running sum of each meter."""
        return self._collect('sum', format_string)

    def counts(self, format_string='{}'):
        """Running count of each meter."""
        return self._collect('count', format_string)
class AverageMeter(object):
    """Computes and stores the most recent value and the running weighted average.

    Attributes (set by ``reset``/``update``):
        val: value passed to the most recent ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count``, or 0 while nothing has been counted.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean over ``n`` samples)."""
        self.val = val
        # Weight by n so that avg is the per-item mean when val is itself a mean
        # over n items; adding val unweighted while counting n skews avg for n > 1.
        self.sum += val * n
        self.count += n
        # Guard the n == 0 case on a fresh meter.
        self.avg = self.sum / self.count if self.count else 0

    def __format__(self, format_spec):
        """Render as ``"<val> (<avg>)"``, applying ``format_spec`` to both numbers."""
        return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format_spec)