Spaces:
Runtime error
Runtime error
| # Copyright (c) SenseTime Research. All rights reserved. | |
| # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| """Train a GAN using the techniques described in the paper | |
| "Training Generative Adversarial Networks with Limited Data".""" | |
| import os | |
| import click | |
| import re | |
| import json | |
| import tempfile | |
| import torch | |
| import dnnlib | |
| import ast | |
| from training import training_loop | |
| from metrics import metric_main | |
| from torch_utils import training_stats | |
| from torch_utils import custom_ops | |
| # ---------------------------------------------------------------------------- | |
| class UserError(Exception): | |
| pass | |
| # ---------------------------------------------------------------------------- | |
| def setup_training_loop_kwargs( | |
| # General options (not included in desc). | |
| gpus=None, # Number of GPUs: <int>, default = 1 gpu | |
| snap=None, # Snapshot interval: <int>, default = 50 ticks | |
| metrics=None, # List of metric names: [], ['fid50k_full'] (default), ... | |
| seed=None, # Random seed: <int>, default = 0 | |
| # Dataset. | |
| data=None, # Training dataset (required): <path> | |
| cond=None, # Train conditional model based on dataset labels: <bool>, default = False | |
| subset=None, # Train with only N images: <int>, default = all | |
| mirror=None, # Augment dataset with x-flips: <bool>, default = False | |
| square=None, | |
| # Base config. | |
| # Base config: 'auto' (default), 'stylegan2', 'paper256', 'paper512', 'paper1024', 'cifar', 'shhq' | |
| cfg=None, | |
| gamma=None, # Override R1 gamma: <float> | |
| kimg=None, # Override training duration: <int> | |
| batch=None, # Override batch size: <int> | |
| # Discriminator augmentation. | |
| aug=None, # Augmentation mode: 'ada' (default), 'noaug', 'fixed' | |
| p=None, # Specify p for 'fixed' (required): <float> | |
| target=None, # Override ADA target for 'ada': <float>, default = depends on aug | |
| # Augmentation pipeline: 'blit', 'geom', 'color', 'filter', 'noise', 'cutout', 'bg', 'bgc' (default), ..., 'bgcfnc' | |
| augpipe=None, | |
| # Transfer learning. | |
| # Load previous network: 'noresume' (default), 'ffhq256', 'ffhq512', 'ffhq1024', 'celebahq256', 'lsundog256', <file>, <url> | |
| resume=None, | |
| freezed=None, # Freeze-D: <int>, default = 0 discriminator layers | |
| # Performance options (not included in desc). | |
| fp32=None, # Disable mixed-precision training: <bool>, default = False | |
| nhwc=None, # Use NHWC memory format with FP16: <bool>, default = False | |
| # Allow PyTorch to use TF32 for matmul and convolutions: <bool>, default = False | |
| allow_tf32=None, | |
| nobench=None, # Disable cuDNN benchmarking: <bool>, default = False | |
| workers=None, # Override number of DataLoader workers: <int>, default = 3 | |
| ): | |
| args = dnnlib.EasyDict() | |
| # ------------------------------------------ | |
| # General options: gpus, snap, metrics, seed | |
| # ------------------------------------------ | |
| if gpus is None: | |
| gpus = 1 | |
| assert isinstance(gpus, int) | |
| if not (gpus >= 1 and gpus & (gpus - 1) == 0): | |
| raise UserError('--gpus must be a power of two') | |
| args.num_gpus = gpus | |
| if snap is None: | |
| snap = 50 | |
| assert isinstance(snap, int) | |
| if snap < 1: | |
| raise UserError('--snap must be at least 1') | |
| args.image_snapshot_ticks = snap | |
| args.network_snapshot_ticks = snap | |
| if metrics is None: | |
| metrics = ['fid50k_full'] | |
| assert isinstance(metrics, list) | |
| if not all(metric_main.is_valid_metric(metric) for metric in metrics): | |
| raise UserError('\n'.join( | |
| ['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) | |
| args.metrics = metrics | |
| if seed is None: | |
| seed = 0 | |
| assert isinstance(seed, int) | |
| args.random_seed = seed | |
| # ------------------------------------------- | |
| # Dataset: data, cond, subset, mirror, square | |
| # ------------------------------------------- | |
| print('square : ', square) | |
| assert data is not None | |
| assert isinstance(data, str) | |
| args.training_set_kwargs = dnnlib.EasyDict( | |
| class_name='training.dataset.ImageFolderDataset', path=data, use_labels=True, max_size=None, xflip=False, square=square) | |
| args.data_loader_kwargs = dnnlib.EasyDict( | |
| pin_memory=True, num_workers=3, prefetch_factor=2) | |
| try: | |
| training_set = dnnlib.util.construct_class_by_name( | |
| **args.training_set_kwargs) # subclass of training.dataset.Dataset | |
| # be explicit about resolution | |
| args.training_set_kwargs.resolution = training_set.resolution | |
| # be explicit about labels | |
| args.training_set_kwargs.use_labels = training_set.has_labels | |
| args.training_set_kwargs.max_size = len( | |
| training_set) # be explicit about dataset size | |
| desc = training_set.name | |
| print('desc: ', desc) | |
| del training_set # conserve memory | |
| except IOError as err: | |
| raise UserError(f'--data: {err}') | |
| if square: | |
| desc += '-square' | |
| else: | |
| desc += '-rectangle' | |
| if cond is None: | |
| cond = False | |
| assert isinstance(cond, bool) | |
| if cond: | |
| if not args.training_set_kwargs.use_labels: | |
| raise UserError( | |
| '--cond=True requires labels specified in dataset.json') | |
| desc += '-cond' | |
| else: | |
| args.training_set_kwargs.use_labels = False | |
| if subset is not None: | |
| assert isinstance(subset, int) | |
| if not 1 <= subset <= args.training_set_kwargs.max_size: | |
| raise UserError( | |
| f'--subset must be between 1 and {args.training_set_kwargs.max_size}') | |
| desc += f'-subset{subset}' | |
| if subset < args.training_set_kwargs.max_size: | |
| args.training_set_kwargs.max_size = subset | |
| args.training_set_kwargs.random_seed = args.random_seed | |
| if mirror is None: | |
| mirror = False | |
| assert isinstance(mirror, bool) | |
| if mirror: | |
| desc += '-mirror' | |
| args.training_set_kwargs.xflip = True | |
| # ------------------------------------ | |
| # Base config: cfg, gamma, kimg, batch | |
| # ------------------------------------ | |
| if cfg is None: | |
| cfg = 'auto' | |
| assert isinstance(cfg, str) | |
| desc += f'-{cfg}' | |
| cfg_specs = { | |
| 'auto': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=2), | |
| # Populated dynamically based on resolution and GPU count. | |
| 'shhq': dict(ref_gpus=-1, kimg=25000, mb=-1, mbstd=-1, fmaps=-1, lrate=-1, gamma=-1, ema=-1, ramp=0.05, map=8), | |
| # Uses mixed-precision, unlike the original StyleGAN2. | |
| 'stylegan2': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=10, ema=10, ramp=None, map=8), | |
| 'paper256': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=0.5, lrate=0.0025, gamma=1, ema=20, ramp=None, map=8), | |
| 'paper512': dict(ref_gpus=8, kimg=25000, mb=64, mbstd=8, fmaps=1, lrate=0.0025, gamma=0.5, ema=20, ramp=None, map=8), | |
| 'paper1024': dict(ref_gpus=8, kimg=25000, mb=32, mbstd=4, fmaps=1, lrate=0.002, gamma=2, ema=10, ramp=None, map=8), | |
| 'cifar': dict(ref_gpus=2, kimg=100000, mb=64, mbstd=32, fmaps=1, lrate=0.0025, gamma=0.01, ema=500, ramp=0.05, map=2), | |
| } | |
| assert cfg in cfg_specs | |
| spec = dnnlib.EasyDict(cfg_specs[cfg]) | |
| if cfg == 'auto' or cfg == 'shhq': | |
| desc += f'{gpus:d}' | |
| spec.ref_gpus = gpus | |
| res = args.training_set_kwargs.resolution | |
| # keep gpu memory consumption at bay | |
| spec.mb = max(min(gpus * min(4096 // res, 32), 64), gpus) | |
| # other hyperparams behave more predictably if mbstd group size remains fixed | |
| spec.mbstd = min(spec.mb // gpus, 4) | |
| spec.fmaps = 1 if res >= 512 else 0.5 | |
| spec.lrate = 0.002 if res >= 1024 else 0.0025 | |
| spec.gamma = 0.0002 * (res ** 2) / spec.mb # heuristic formula | |
| spec.ema = spec.mb * 10 / 32 | |
| args.G_kwargs = dnnlib.EasyDict(class_name='training.networks.Generator', z_dim=512, w_dim=512, | |
| mapping_kwargs=dnnlib.EasyDict(), synthesis_kwargs=dnnlib.EasyDict(), square=square) | |
| args.D_kwargs = dnnlib.EasyDict(class_name='training.networks.Discriminator', block_kwargs=dnnlib.EasyDict( | |
| ), mapping_kwargs=dnnlib.EasyDict(), epilogue_kwargs=dnnlib.EasyDict(), square=square) | |
| args.G_kwargs.synthesis_kwargs.channel_base = args.D_kwargs.channel_base = int( | |
| spec.fmaps * 32768) | |
| args.G_kwargs.synthesis_kwargs.channel_max = args.D_kwargs.channel_max = 512 | |
| args.G_kwargs.mapping_kwargs.num_layers = spec.map | |
| # enable mixed-precision training | |
| args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 4 | |
| # clamp activations to avoid float16 overflow | |
| args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = 256 | |
| args.D_kwargs.epilogue_kwargs.mbstd_group_size = spec.mbstd | |
| args.G_opt_kwargs = dnnlib.EasyDict( | |
| class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8) | |
| args.D_opt_kwargs = dnnlib.EasyDict( | |
| class_name='torch.optim.Adam', lr=spec.lrate, betas=[0, 0.99], eps=1e-8) | |
| args.loss_kwargs = dnnlib.EasyDict( | |
| class_name='training.loss.StyleGAN2Loss', r1_gamma=spec.gamma) | |
| args.total_kimg = spec.kimg | |
| args.batch_size = spec.mb | |
| args.batch_gpu = spec.mb // spec.ref_gpus | |
| args.ema_kimg = spec.ema | |
| args.ema_rampup = spec.ramp | |
| if cfg == 'cifar': | |
| args.loss_kwargs.pl_weight = 0 # disable path length regularization | |
| args.loss_kwargs.style_mixing_prob = 0 # disable style mixing | |
| args.D_kwargs.architecture = 'orig' # disable residual skip connections | |
| if gamma is not None: | |
| assert isinstance(gamma, float) | |
| if not gamma >= 0: | |
| raise UserError('--gamma must be non-negative') | |
| desc += f'-gamma{gamma:g}' | |
| args.loss_kwargs.r1_gamma = gamma | |
| if kimg is not None: | |
| assert isinstance(kimg, int) | |
| if not kimg >= 1: | |
| raise UserError('--kimg must be at least 1') | |
| desc += f'-kimg{kimg:d}' | |
| args.total_kimg = kimg | |
| if batch is not None: | |
| assert isinstance(batch, int) | |
| if not (batch >= 1 and batch % gpus == 0): | |
| raise UserError( | |
| '--batch must be at least 1 and divisible by --gpus') | |
| desc += f'-batch{batch}' | |
| args.batch_size = batch | |
| args.batch_gpu = batch // gpus | |
| # --------------------------------------------------- | |
| # Discriminator augmentation: aug, p, target, augpipe | |
| # --------------------------------------------------- | |
| if aug is None: | |
| aug = 'ada' | |
| else: | |
| assert isinstance(aug, str) | |
| desc += f'-{aug}' | |
| if aug == 'ada': | |
| args.ada_target = 0.6 | |
| elif aug == 'noaug': | |
| pass | |
| elif aug == 'fixed': | |
| if p is None: | |
| raise UserError(f'--aug={aug} requires specifying --p') | |
| else: | |
| raise UserError(f'--aug={aug} not supported') | |
| if p is not None: | |
| assert isinstance(p, float) | |
| if aug != 'fixed': | |
| raise UserError('--p can only be specified with --aug=fixed') | |
| if not 0 <= p <= 1: | |
| raise UserError('--p must be between 0 and 1') | |
| desc += f'-p{p:g}' | |
| args.augment_p = p | |
| if target is not None: | |
| assert isinstance(target, float) | |
| if aug != 'ada': | |
| raise UserError('--target can only be specified with --aug=ada') | |
| if not 0 <= target <= 1: | |
| raise UserError('--target must be between 0 and 1') | |
| desc += f'-target{target:g}' | |
| args.ada_target = target | |
| assert augpipe is None or isinstance(augpipe, str) | |
| if augpipe is None: | |
| augpipe = 'bgc' | |
| else: | |
| if aug == 'noaug': | |
| raise UserError('--augpipe cannot be specified with --aug=noaug') | |
| desc += f'-{augpipe}' | |
| augpipe_specs = { | |
| 'blit': dict(xflip=1, rotate90=1, xint=1), | |
| 'geom': dict(scale=1, rotate=1, aniso=1, xfrac=1), | |
| 'color': dict(brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1), | |
| 'filter': dict(imgfilter=1), | |
| 'noise': dict(noise=1), | |
| 'cutout': dict(cutout=1), | |
| 'bg': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1), | |
| 'bgc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1), | |
| 'bgcf': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1), | |
| 'bgcfn': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1), | |
| 'bgcfnc': dict(xflip=1, rotate90=1, xint=1, scale=1, rotate=1, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1, imgfilter=1, noise=1, cutout=1), | |
| 'body': dict(xflip=1, rotate90=0, xint=1, scale=1, rotate=0, aniso=1, xfrac=1, brightness=1, contrast=1, lumaflip=1, hue=1, saturation=1) | |
| } | |
| assert augpipe in augpipe_specs | |
| if aug != 'noaug': | |
| args.augment_kwargs = dnnlib.EasyDict( | |
| class_name='training.augment.AugmentPipe', **augpipe_specs[augpipe]) | |
| # ---------------------------------- | |
| # Transfer learning: resume, freezed | |
| # ---------------------------------- | |
| resume_specs = { | |
| 'ffhq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res256-mirror-paper256-noaug.pkl', | |
| 'ffhq512': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res512-mirror-stylegan2-noaug.pkl', | |
| 'ffhq1024': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/ffhq-res1024-mirror-stylegan2-noaug.pkl', | |
| 'celebahq256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/celebahq-res256-mirror-paper256-kimg100000-ada-target0.5.pkl', | |
| 'lsundog256': 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/transfer-learning-source-nets/lsundog-res256-paper256-kimg100000-noaug.pkl', | |
| } | |
| assert resume is None or isinstance(resume, str) | |
| if resume is None: | |
| resume = 'noresume' | |
| elif resume == 'noresume': | |
| desc += '-noresume' | |
| elif resume in resume_specs: | |
| desc += f'-resume{resume}' | |
| args.resume_pkl = resume_specs[resume] # predefined url | |
| else: | |
| desc += '-resumecustom' | |
| args.resume_pkl = resume # custom path or url | |
| if resume != 'noresume': | |
| args.ada_kimg = 100 # make ADA react faster at the beginning | |
| args.ema_rampup = None # disable EMA rampup | |
| if freezed is not None: | |
| assert isinstance(freezed, int) | |
| if not freezed >= 0: | |
| raise UserError('--freezed must be non-negative') | |
| desc += f'-freezed{freezed:d}' | |
| args.D_kwargs.block_kwargs.freeze_layers = freezed | |
| # ------------------------------------------------- | |
| # Performance options: fp32, nhwc, nobench, workers | |
| # ------------------------------------------------- | |
| if fp32 is None: | |
| fp32 = False | |
| assert isinstance(fp32, bool) | |
| if fp32: | |
| args.G_kwargs.synthesis_kwargs.num_fp16_res = args.D_kwargs.num_fp16_res = 0 | |
| args.G_kwargs.synthesis_kwargs.conv_clamp = args.D_kwargs.conv_clamp = None | |
| if nhwc is None: | |
| nhwc = False | |
| assert isinstance(nhwc, bool) | |
| if nhwc: | |
| args.G_kwargs.synthesis_kwargs.fp16_channels_last = args.D_kwargs.block_kwargs.fp16_channels_last = True | |
| if nobench is None: | |
| nobench = False | |
| assert isinstance(nobench, bool) | |
| if nobench: | |
| args.cudnn_benchmark = False | |
| if allow_tf32 is None: | |
| allow_tf32 = False | |
| assert isinstance(allow_tf32, bool) | |
| if allow_tf32: | |
| args.allow_tf32 = True | |
| if workers is not None: | |
| assert isinstance(workers, int) | |
| if not workers >= 1: | |
| raise UserError('--workers must be at least 1') | |
| args.data_loader_kwargs.num_workers = workers | |
| return desc, args | |
| # ---------------------------------------------------------------------------- | |
| def subprocess_fn(rank, args, temp_dir): | |
| dnnlib.util.Logger(file_name=os.path.join( | |
| args.run_dir, 'log.txt'), file_mode='a', should_flush=True) | |
| # Init torch.distributed. | |
| if args.num_gpus > 1: | |
| init_file = os.path.abspath(os.path.join( | |
| temp_dir, '.torch_distributed_init')) | |
| if os.name == 'nt': | |
| init_method = 'file:///' + init_file.replace('\\', '/') | |
| torch.distributed.init_process_group( | |
| backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) | |
| else: | |
| init_method = f'file://{init_file}' | |
| torch.distributed.init_process_group( | |
| backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) | |
| # Init torch_utils. | |
| sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None | |
| training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) | |
| if rank != 0: | |
| custom_ops.verbosity = 'none' | |
| # Execute training loop. | |
| training_loop.training_loop(rank=rank, **args) | |
| # ---------------------------------------------------------------------------- | |
| class CommaSeparatedList(click.ParamType): | |
| name = 'list' | |
| def convert(self, value, param, ctx): | |
| _ = param, ctx | |
| if value is None or value.lower() == 'none' or value == '': | |
| return [] | |
| return value.split(',') | |
| # ---------------------------------------------------------------------------- | |
| # General options. | |
| # Dataset. | |
| # Base config. | |
| # Discriminator augmentation. | |
| # Transfer learning. | |
| # Performance options. | |
| def main(ctx, outdir, dry_run, **config_kwargs): | |
| """Train a GAN using the techniques described in the paper | |
| "Training Generative Adversarial Networks with Limited Data". | |
| Examples: | |
| \b | |
| # Train with custom dataset using 1 GPU. | |
| python train.py --outdir=~/training-runs --data=~/mydataset.zip --gpus=1 | |
| \b | |
| # Train class-conditional CIFAR-10 using 2 GPUs. | |
| python train.py --outdir=~/training-runs --data=~/datasets/cifar10.zip \\ | |
| --gpus=2 --cfg=cifar --cond=1 | |
| \b | |
| # Transfer learn MetFaces from FFHQ using 4 GPUs. | |
| python train.py --outdir=~/training-runs --data=~/datasets/metfaces.zip \\ | |
| --gpus=4 --cfg=paper1024 --mirror=1 --resume=ffhq1024 --snap=10 | |
| \b | |
| # Reproduce original StyleGAN2 config F. | |
| python train.py --outdir=~/training-runs --data=~/datasets/ffhq.zip \\ | |
| --gpus=8 --cfg=stylegan2 --mirror=1 --aug=noaug | |
| \b | |
| Base configs (--cfg): | |
| auto Automatically select reasonable defaults based on resolution | |
| and GPU count. Good starting point for new datasets. | |
| stylegan2 Reproduce results for StyleGAN2 config F at 1024x1024. | |
| paper256 Reproduce results for FFHQ and LSUN Cat at 256x256. | |
| paper512 Reproduce results for BreCaHAD and AFHQ at 512x512. | |
| paper1024 Reproduce results for MetFaces at 1024x1024. | |
| cifar Reproduce results for CIFAR-10 at 32x32. | |
| \b | |
| Transfer learning source networks (--resume): | |
| ffhq256 FFHQ trained at 256x256 resolution. | |
| ffhq512 FFHQ trained at 512x512 resolution. | |
| ffhq1024 FFHQ trained at 1024x1024 resolution. | |
| celebahq256 CelebA-HQ trained at 256x256 resolution. | |
| lsundog256 LSUN Dog trained at 256x256 resolution. | |
| <PATH or URL> Custom network pickle. | |
| """ | |
| dnnlib.util.Logger(should_flush=True) | |
| # Setup training options. | |
| try: | |
| run_desc, args = setup_training_loop_kwargs(**config_kwargs) | |
| except UserError as err: | |
| ctx.fail(err) | |
| # Pick output directory. | |
| prev_run_dirs = [] | |
| if os.path.isdir(outdir): | |
| prev_run_dirs = [x for x in os.listdir( | |
| outdir) if os.path.isdir(os.path.join(outdir, x))] | |
| prev_run_ids = [re.match(r'^\d+', x) for x in prev_run_dirs] | |
| prev_run_ids = [int(x.group()) for x in prev_run_ids if x is not None] | |
| cur_run_id = max(prev_run_ids, default=-1) + 1 | |
| args.run_dir = os.path.join(outdir, f'{cur_run_id:05d}-{run_desc}') | |
| assert not os.path.exists(args.run_dir) | |
| # Print options. | |
| print() | |
| print('Training options:') | |
| print(json.dumps(args, indent=2)) | |
| print() | |
| print(f'Output directory: {args.run_dir}') | |
| print(f'Training data: {args.training_set_kwargs.path}') | |
| print(f'Training duration: {args.total_kimg} kimg') | |
| print(f'Number of GPUs: {args.num_gpus}') | |
| print(f'Number of images: {args.training_set_kwargs.max_size}') | |
| print(f'Image resolution: {args.training_set_kwargs.resolution}') | |
| print(f'Conditional model: {args.training_set_kwargs.use_labels}') | |
| print(f'Dataset x-flips: {args.training_set_kwargs.xflip}') | |
| print() | |
| # Dry run? | |
| if dry_run: | |
| print('Dry run; exiting.') | |
| return | |
| # Create output directory. | |
| print('Creating output directory...') | |
| os.makedirs(args.run_dir, exist_ok=True) | |
| with open(os.path.join(args.run_dir, 'training_options.json'), 'wt') as f: | |
| json.dump(args, f, indent=2) | |
| # Launch processes. | |
| print('Launching processes...') | |
| torch.multiprocessing.set_start_method('spawn') | |
| with tempfile.TemporaryDirectory() as temp_dir: | |
| if args.num_gpus == 1: | |
| subprocess_fn(rank=0, args=args, temp_dir=temp_dir) | |
| else: | |
| torch.multiprocessing.spawn(fn=subprocess_fn, args=( | |
| args, temp_dir), nprocs=args.num_gpus) | |
| # ---------------------------------------------------------------------------- | |
| if __name__ == "__main__": | |
| main() # pylint: disable=no-value-for-parameter | |
| # ---------------------------------------------------------------------------- | |