Spaces:
Running
Running
File size: 4,645 Bytes
426ffb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
from config import *
import json
import os
import pprint as pp
import random
from datetime import date
from pathlib import Path
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch import optim as optim
def setup_train(args):
    """Prepare a training run and return the directory its artifacts go to.

    Configures visible GPUs on ``args``, creates a fresh experiment folder,
    writes the (now GPU-annotated) arguments there as JSON, and pretty-prints
    every argument whose value is not ``None``.
    """
    # Order matters: set_up_gpu adds args.num_gpu, which the JSON export should include.
    set_up_gpu(args)
    export_root = create_experiment_export_folder(args)
    export_experiments_config_as_json(args, export_root)

    populated_args = {}
    for key, value in vars(args).items():
        if value is not None:
            populated_args[key] = value
    pp.pprint(populated_args, width=1)

    return export_root
def create_experiment_export_folder(args):
    """Create a new, uniquely named run folder under ``args.experiment_dir``.

    The parent ``args.experiment_dir`` (and any missing ancestors) is created
    if needed. The run folder name comes from ``get_name_of_experiment_path``
    using ``args.experiment_description``.

    Returns:
        The path of the newly created run folder.

    Raises:
        OSError: if the run folder itself cannot be created (for example it
            already exists).
    """
    experiment_dir = args.experiment_dir
    experiment_description = args.experiment_description
    # makedirs also creates missing parent directories and tolerates another
    # process creating the same directory concurrently. A bare mkdir after an
    # exists() check raised FileNotFoundError for nested paths and was racy.
    os.makedirs(experiment_dir, exist_ok=True)
    experiment_path = get_name_of_experiment_path(experiment_dir, experiment_description)
    # Plain mkdir on purpose: the run folder must be brand new.
    os.mkdir(experiment_path)
    print('Folder created: ' + os.path.abspath(experiment_path))
    return experiment_path
def get_name_of_experiment_path(experiment_dir, experiment_description):
    """Return the first unused run path of the form ``<dir>/<description>_<today>_<k>``.

    ``<today>`` is ``str(date.today())``. ``k`` is the smallest non-negative
    integer for which that path does not exist yet. Nothing is created on disk.
    """
    dated_prefix = os.path.join(experiment_dir, experiment_description + "_" + str(date.today()))
    suffix = 0
    while os.path.exists(dated_prefix + "_" + str(suffix)):
        suffix += 1
    return dated_prefix + "_" + str(suffix)
def _get_experiment_index(experiment_path):
    """Return the smallest k >= 0 such that ``experiment_path + "_" + str(k)`` does not exist."""
    candidate = 0
    while True:
        if not os.path.exists(experiment_path + "_" + str(candidate)):
            return candidate
        candidate += 1
def load_weights(model, path):
    """Placeholder that intentionally does nothing.

    Neither ``model`` nor ``path`` is used and the function returns ``None``.
    Actual checkpoint loading is done by ``load_pretrained_weights``.
    """
    return None
def save_test_result(export_root, result):
    """Serialize ``result`` as 2-space-indented JSON into ``<export_root>/test_result.txt``.

    Any existing file of that name is overwritten.
    """
    target = Path(export_root) / 'test_result.txt'
    with open(target, 'w') as handle:
        json.dump(result, handle, indent=2)
def export_experiments_config_as_json(args, experiment_path):
    """Write ``vars(args)`` as 2-space-indented JSON to ``<experiment_path>/config.json``."""
    config_file = os.path.join(experiment_path, 'config.json')
    with open(config_file, 'w') as outfile:
        # vars() is taken after the file is opened, matching the original order.
        settings = vars(args)
        json.dump(settings, outfile, indent=2)
def fix_random_seed_as(random_seed):
    """Seed Python, PyTorch (CPU and all CUDA devices) and NumPy RNGs with ``random_seed``.

    Also puts cuDNN in deterministic mode with benchmark autotuning disabled.
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed)
    for seed_fn in seeders:
        seed_fn(random_seed)
    # Prefer reproducible cuDNN kernels over autotuned (possibly nondeterministic) ones.
    cudnn.deterministic = True
    cudnn.benchmark = False
def set_up_gpu(args):
    """Restrict visible CUDA devices to ``args.device_idx`` and record how many there are.

    ``args.device_idx`` is a comma-separated string of device indices such as
    ``"0,1"``. It is exported verbatim as ``CUDA_VISIBLE_DEVICES``.
    ``args.num_gpu`` is set to the number of non-empty entries, so ``""``
    yields 0 and stray or trailing commas (``"0,1,"``) are not counted.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_idx
    # Previously len(split(",")) counted empty fragments, e.g. "" -> 1 GPU
    # even though an empty CUDA_VISIBLE_DEVICES hides every device.
    args.num_gpu = len([idx for idx in args.device_idx.split(",") if idx.strip()])
def load_pretrained_weights(model, path):
    """Load model weights from the checkpoint file at ``path`` into ``model``.

    The checkpoint is expected to be a dict holding the weights under
    ``STATE_DICT_KEY`` (from ``config``) or, if that key is absent, under
    ``'state_dict'``.

    Raises:
        KeyError: if neither key is present in the checkpoint.
    """
    # Deserialize onto CPU. A checkpoint saved on a CUDA device that is not
    # visible now (set_up_gpu narrows CUDA_VISIBLE_DEVICES) would otherwise
    # fail to load. load_state_dict copies the tensors into the model's
    # existing parameters, so they still land on the model's device.
    chk_dict = torch.load(os.path.abspath(path), map_location='cpu')
    model_state_dict = chk_dict[STATE_DICT_KEY] if STATE_DICT_KEY in chk_dict else chk_dict['state_dict']
    model.load_state_dict(model_state_dict)
def setup_to_resume(args, model, optimizer):
    """Restore model and optimizer state from the most recent checkpoint of a previous run.

    Reads ``<abspath(args.resume_training)>/models/checkpoint-recent.pth``. The
    checkpoint dict must contain ``STATE_DICT_KEY`` and
    ``OPTIMIZER_STATE_DICT_KEY`` (both from ``config``).
    """
    checkpoint_path = os.path.join(os.path.abspath(args.resume_training), 'models/checkpoint-recent.pth')
    # Load onto CPU so a checkpoint written on a now-invisible CUDA device can
    # still be read. Module.load_state_dict copies into the existing parameters,
    # and Optimizer.load_state_dict moves the restored per-parameter state onto
    # each parameter's device.
    chk_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(chk_dict[STATE_DICT_KEY])
    optimizer.load_state_dict(chk_dict[OPTIMIZER_STATE_DICT_KEY])
def create_optimizer(model, args):
    """Build the optimizer for ``model`` named by ``args.optimizer``.

    ``'Adam'`` selects Adam (uses ``args.lr`` and ``args.weight_decay``).
    Any other value falls back to SGD, which additionally reads
    ``args.momentum``.
    """
    params = model.parameters()
    if args.optimizer != 'Adam':
        return optim.SGD(params, lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
    return optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
class AverageMeterSet(object):
    """A name -> AverageMeter mapping whose meters are created on first update."""

    def __init__(self, meters=None):
        # Any falsy argument (None or an empty mapping) is replaced by a new dict.
        self.meters = meters or {}

    def __getitem__(self, key):
        """Return the meter stored under ``key``.

        An unknown key yields a fresh meter that has already received one
        ``update(0)``. That meter is NOT added to ``self.meters``.
        """
        if key in self.meters:
            return self.meters[key]
        placeholder = AverageMeter()
        placeholder.update(0)
        return placeholder

    def update(self, name, value, n=1):
        """Feed ``value`` (with weight ``n``) to meter ``name``, creating it if needed."""
        meters = self.meters
        if name not in meters:
            meters[name] = AverageMeter()
        meters[name].update(value, n)

    def reset(self):
        """Reset every stored meter in place."""
        for tracked in self.meters.values():
            tracked.reset()

    def _collect(self, attribute, format_string):
        # Shared body of values()/averages()/sums()/counts(): map each meter's
        # name, passed through format_string, to the meter's `attribute`.
        return {
            format_string.format(name): getattr(meter, attribute)
            for name, meter in self.meters.items()
        }

    def values(self, format_string='{}'):
        """Latest value of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('val', format_string)

    def averages(self, format_string='{}'):
        """Running average of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('avg', format_string)

    def sums(self, format_string='{}'):
        """Accumulated sum of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('sum', format_string)

    def counts(self, format_string='{}'):
        """Accumulated count of each meter, keyed by ``format_string.format(name)``."""
        return self._collect('count', format_string)
class AverageMeter(object):
    """Computes and stores the average and current value.

    ``update(val, n)`` treats ``val`` as the mean of ``n`` samples. It adds
    ``val * n`` to ``sum`` and ``n`` to ``count``, and sets ``avg`` to
    ``sum / count``, so ``avg`` stays a per-sample mean across batches of
    different sizes.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running average, sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples."""
        self.val = val
        # Weight by n. The previous `sum += val` paired with `count += n`
        # made avg wrong whenever n != 1.
        self.sum += val * n
        self.count += n
        # Guard against count == 0 (e.g. update(x, n=0) on a fresh meter),
        # which previously raised ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __format__(self, format_spec):
        """Render as ``"<val> (<avg>)"`` with ``format_spec`` applied to both numbers."""
        return "{self.val:{format}} ({self.avg:{format}})".format(self=self, format=format_spec)
|