#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  :   Peike Li
@Contact :   [email protected]
@File    :   train.py
@Time    :   8/4/19 3:36 PM
@Desc    :
@License :   This source code is licensed under the license found in the
             LICENSE file in the root directory of this source tree.
"""

import os
import json
import timeit
import argparse

import torch
import torch.optim as optim
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils import data

import networks
import utils.schp as schp
from datasets.datasets import LIPDataSet
from datasets.target_generation import generate_edge_tensor
from utils.transforms import BGR2RGB_transform
from utils.criterion import CriterionAll
from utils.encoding import DataParallelModel, DataParallelCriterion
from utils.warmup_scheduler import SGDRScheduler

def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
      The parsed arguments as an argparse.Namespace.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    # Network Structure
    parser.add_argument("--arch", type=str, default='resnet101')
    # Data Preference
    parser.add_argument("--data-dir", type=str, default='./data/LIP')
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--input-size", type=str, default='473,473')
    parser.add_argument("--num-classes", type=int, default=20)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument("--random-mirror", action="store_true")
    parser.add_argument("--random-scale", action="store_true")
    # Training Strategy
    parser.add_argument("--learning-rate", type=float, default=7e-3)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=5e-4)
    parser.add_argument("--gpu", type=str, default='0,1,2')
    parser.add_argument("--start-epoch", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=150)
    parser.add_argument("--eval-epochs", type=int, default=10)
    parser.add_argument("--imagenet-pretrain", type=str, default='./pretrain_model/resnet101-imagenet.pth')
    parser.add_argument("--log-dir", type=str, default='./log')
    parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
    parser.add_argument("--schp-start", type=int, default=100, help='schp start epoch')
    parser.add_argument("--cycle-epochs", type=int, default=10, help='schp cyclical epoch')
    parser.add_argument("--schp-restore", type=str, default='./log/schp_checkpoint.pth.tar')
    parser.add_argument("--lambda-s", type=float, default=1, help='segmentation loss weight')
    parser.add_argument("--lambda-e", type=float, default=1, help='edge loss weight')
    parser.add_argument("--lambda-c", type=float, default=0.1, help='segmentation-edge consistency loss weight')
    return parser.parse_args()
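
# Example invocation (illustrative only; assumes the LIP dataset sits under ./data/LIP and the
# ImageNet-pretrained ResNet-101 weights under ./pretrain_model/resnet101-imagenet.pth, as in
# the argument defaults above):
#   python train.py --arch resnet101 --batch-size 16 --gpu 0,1,2 --random-mirror --random-scale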

def main():
    args = get_arguments()
    print(args)

    start_epoch = 0
    cycle_n = 0

    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    gpus = [int(i) for i in args.gpu.split(',')]
    if not args.gpu == 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    input_size = list(map(int, args.input_size.split(',')))

    cudnn.enabled = True
    cudnn.benchmark = True

    # Model Initialization
    AugmentCE2P = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
    model = DataParallelModel(AugmentCE2P)
    model.cuda()

    IMAGE_MEAN = AugmentCE2P.mean
    IMAGE_STD = AugmentCE2P.std
    INPUT_SPACE = AugmentCE2P.input_space
    print('image mean: {}'.format(IMAGE_MEAN))
    print('image std: {}'.format(IMAGE_STD))
    print('input space:{}'.format(INPUT_SPACE))

    restore_from = args.model_restore
    if os.path.exists(restore_from):
        print('Resume training from {}'.format(restore_from))
        checkpoint = torch.load(restore_from)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
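
    # A second copy of the network serves as the SCHP "teacher": its weights are a running
    # average of the online model taken at the end of each self-correction cycle, and its
    # predictions supply the soft labels used during the cyclical training phase below.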
    SCHP_AugmentCE2P = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
    schp_model = DataParallelModel(SCHP_AugmentCE2P)
    schp_model.cuda()

    if os.path.exists(args.schp_restore):
        print('Resuming schp checkpoint from {}'.format(args.schp_restore))
        schp_checkpoint = torch.load(args.schp_restore)
        schp_model_state_dict = schp_checkpoint['state_dict']
        cycle_n = schp_checkpoint['cycle_n']
        schp_model.load_state_dict(schp_model_state_dict)

    # Loss Function
    criterion = CriterionAll(lambda_1=args.lambda_s, lambda_2=args.lambda_e, lambda_3=args.lambda_c,
                             num_classes=args.num_classes)
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    # Data Loader
    if INPUT_SPACE == 'BGR':
        print('BGR Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_MEAN,
                                 std=IMAGE_STD),
        ])
    elif INPUT_SPACE == 'RGB':
        print('RGB Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            BGR2RGB_transform(),
            transforms.Normalize(mean=IMAGE_MEAN,
                                 std=IMAGE_STD),
        ])

    train_dataset = LIPDataSet(args.data_dir, 'train', crop_size=input_size, transform=transform)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size * len(gpus),
                                   num_workers=16, shuffle=True, pin_memory=True, drop_last=True)
    print('Total training samples: {}'.format(len(train_dataset)))

    # Optimizer Initialization
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    lr_scheduler = SGDRScheduler(optimizer, total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100, warmup_epoch=10,
                                 start_cyclical=args.schp_start, cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    total_iters = args.epochs * len(train_loader)
    start = timeit.default_timer()
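
    # Training loop: judging by the scheduler arguments, epochs before args.schp_start follow
    # a 10-epoch warm-up and then anneal towards eta_min; from args.schp_start onwards the
    # schedule restarts every args.cycle_epochs epochs, and each completed cycle triggers the
    # self-correction steps at the end of the epoch.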
    for epoch in range(start_epoch, args.epochs):
        lr_scheduler.step(epoch=epoch)
        lr = lr_scheduler.get_lr()[0]

        model.train()
        for i_iter, batch in enumerate(train_loader):
            i_iter += len(train_loader) * epoch

            images, labels, _ = batch
            labels = labels.cuda(non_blocking=True)

            edges = generate_edge_tensor(labels)
            labels = labels.type(torch.cuda.LongTensor)
            edges = edges.type(torch.cuda.LongTensor)

            preds = model(images)

            # Online Self Correction Cycle with Label Refinement
            if cycle_n >= 1:
                with torch.no_grad():
                    soft_preds = schp_model(images)
                    soft_parsing = []
                    soft_edge = []
                    for soft_pred in soft_preds:
                        soft_parsing.append(soft_pred[0][-1])
                        soft_edge.append(soft_pred[1][-1])
                    soft_preds = torch.cat(soft_parsing, dim=0)
                    soft_edges = torch.cat(soft_edge, dim=0)
            else:
                soft_preds = None
                soft_edges = None

            loss = criterion(preds, [labels, edges, soft_preds, soft_edges], cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                print('iter = {} of {} completed, lr = {}, loss = {}'.format(i_iter, total_iters, lr,
                                                                             loss.data.cpu().numpy()))
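
        # End of epoch: periodically save a checkpoint of the online model and, once the SCHP
        # phase has started, fold the current weights into the teacher, re-estimate its
        # BatchNorm statistics over the training set, and save the aggregated checkpoint.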
        if (epoch + 1) % (args.eval_epochs) == 0:
            schp.save_schp_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, False, args.log_dir, filename='checkpoint_{}.pth.tar'.format(epoch + 1))

        # Self Correction Cycle with Model Aggregation
        if (epoch + 1) >= args.schp_start and (epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            print('Self-correction cycle number {}'.format(cycle_n))
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(train_loader, schp_model)
            schp.save_schp_checkpoint({
                'state_dict': schp_model.state_dict(),
                'cycle_n': cycle_n,
            }, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        print('epoch = {} of {} completed using {} s'.format(epoch, args.epochs,
                                                             (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print('Training Finished in {} seconds'.format(end - start))


if __name__ == '__main__':
    main()