Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| import os | |
| import time | |
| import copy | |
| import json | |
| import dill as pickle | |
| import psutil | |
| import PIL.Image | |
| import numpy as np | |
| import torch | |
| import dnnlib | |
| from torch_utils import misc | |
| from torch_utils import training_stats | |
| from torch_utils.ops import conv2d_gradfix | |
| from torch_utils.ops import grid_sample_gradfix | |
| from torchvision.utils import save_image | |
| import math | |
| import legacy | |
| from metrics import metric_main | |
| import torch.nn.functional as F | |
| np.set_printoptions(formatter={'float': '{:0.2f}'.format}) | |
| from collections import Counter | |
| #---------------------------------------------------------------------------- | |
class SparsestVector:
    """Remember the sparsest tensor (fewest non-zero entries) offered so far."""

    def __init__(self):
        # Best tensor seen so far; None until the first add().
        self.sparsest_vector = None

    def add(self, vector):
        """Offer ``vector``; keep a copy of it only if it is strictly sparser than the stored one.

        Sparsity is measured with ``torch.count_nonzero``. Ties keep the earlier
        tensor. A clone is stored so that later in-place edits of ``vector`` by
        the caller cannot silently change the recorded tensor.
        """
        if self.sparsest_vector is None:
            self.sparsest_vector = vector.clone()
            return
        current_nonzero = torch.count_nonzero(self.sparsest_vector).item()
        new_nonzero = torch.count_nonzero(vector).item()
        # Keep the new vector only if it's sparser (fewer non-zero elements).
        if new_nonzero < current_nonzero:
            self.sparsest_vector = vector.clone()

    def check(self):
        """Return a copy of the sparsest tensor stored so far, or ``None`` if nothing was added.

        A copy is returned so callers that modify the result in place (e.g. after
        assigning it into a module's state) do not corrupt the stored record.
        """
        if self.sparsest_vector is None:
            return None
        return self.sparsest_vector.clone()
def setup_snapshot_image_grid(training_set, random_seed=0):
    """Pick the training samples shown in the snapshot image grid.

    Unlabeled datasets: a ``gw x gh`` grid of randomly shuffled samples (indices
    wrap around when the dataset has fewer than ``gw * gh`` items).

    Labeled datasets: one row per distinct raw label (rows ordered by the sorted,
    reversed label tuple), each row holding ``gw`` shuffled samples of that
    label (indices wrap when a label has fewer than ``gw`` samples). The column
    count is doubled when ``image_shape[1] < 256``.

    Args:
        training_set: dataset exposing ``image_shape``, ``has_labels``,
            ``__len__``, ``__getitem__`` returning ``(image, label)`` and, for
            labeled data, ``get_details(idx).raw_label``.
        random_seed: seed for the shuffles that choose the samples.

    Returns:
        ``((grid_w, grid_h), images, labels, num_domains)``: the grid size, the
        stacked images and labels in row-major grid order, and the number of
        distinct labels (0 for unlabeled datasets).
    """
    rnd = np.random.RandomState(random_seed)
    gw = int(np.clip(768 * 2 // training_set.image_shape[2], 7, 32))
    gh = int(np.clip(432 * 2 // training_set.image_shape[1], 4, 32))

    if not training_set.has_labels:
        # No labels => show a random subset of training samples.
        all_indices = list(range(len(training_set)))
        rnd.shuffle(all_indices)
        grid_indices = [all_indices[i % len(all_indices)] for i in range(gw * gh)]
        num_domains = 0
        # Bug fix: previously returned (gw, 0) here, which made the grid
        # unusable by save_image_grid even though gw * gh samples were chosen.
        grid_size = (gw, gh)
    else:
        # Group training samples by label.
        label_groups = dict()  # label => [idx, ...]
        for idx in range(len(training_set)):
            label = tuple(training_set.get_details(idx).raw_label.flat[::-1])
            label_groups.setdefault(label, []).append(idx)
        if training_set.image_shape[1] < 256:
            gw *= 2
        # Reorder: shuffle each label's samples in sorted-label order.
        label_order = sorted(label_groups.keys())
        for label in label_order:
            rnd.shuffle(label_groups[label])
        # Organize into grid: exactly one row of gw samples per label.
        grid_indices = []
        for label in label_order:
            indices = label_groups[label]
            grid_indices += [indices[x % len(indices)] for x in range(gw)]
        num_domains = len(label_groups)
        grid_size = (gw, num_domains)

    # Load data.
    images, labels = zip(*[training_set[i] for i in grid_indices])
    return grid_size, np.stack(images), np.stack(labels), num_domains
| #---------------------------------------------------------------------------- | |
def save_image_grid(img, fname, drange, grid_size):
    """Quantize a batch of images to uint8, tile them into one grid, and save it.

    Args:
        img: array-like unpackable as (N, C, H, W), with N equal to
            grid_w * grid_h; C must be 1 (grayscale) or 3 (RGB).
        fname: destination path handed to ``PIL.Image.Image.save``.
        drange: (lo, hi) input value range, mapped linearly onto [0, 255]
            before rounding and clipping.
        grid_size: (grid_w, grid_h) tiles per row and number of rows.
    """
    low, high = drange
    pixels = np.asarray(img, dtype=np.float32)
    pixels = (pixels - low) * (255 / (high - low))
    pixels = np.rint(pixels).clip(0, 255).astype(np.uint8)

    grid_w, grid_h = grid_size
    _, channels, height, width = pixels.shape
    tiles = pixels.reshape(grid_h, grid_w, channels, height, width)
    # (row, col, C, H, W) -> (row, H, col, W, C) so tiles interleave into one image.
    canvas = tiles.transpose(0, 3, 1, 4, 2).reshape(grid_h * height, grid_w * width, channels)

    assert channels in [1, 3]
    if channels == 1:
        PIL.Image.fromarray(canvas[:, :, 0], 'L').save(fname)
    elif channels == 3:
        PIL.Image.fromarray(canvas, 'RGB').save(fname)
class VectorHistoryChecker:
    """Ring buffer of the last ``m`` (b, d) snapshots, used to test row-wise stability.

    ``check_history(x)`` reports, for each of the ``b`` rows, whether that row of
    ``x`` is exactly equal to the same row in every one of the ``m`` stored
    snapshots.
    """

    def __init__(self, b, d, m):
        self.b = b  # number of rows
        self.d = d  # number of columns
        self.m = m  # number of snapshots kept
        # +inf sentinel (not zeros): it never equals a real entry, so no row is
        # reported consistent until all m slots hold real snapshots. This is the
        # same value ``torch.ones(...) * 1e99`` yields, since 1e99 overflows float32.
        self.history = torch.full((b, d, m), float('inf'))
        self.current_index = 0  # slot overwritten by the next update_history()

    def update_history(self, new_version):
        """Store ``new_version`` (copied to CPU) in the oldest slot of the ring buffer."""
        self.history[:, :, self.current_index] = new_version.cpu()
        self.current_index = (self.current_index + 1) % self.m

    def check_history(self, input_version):
        """Return a (b,) bool tensor: True where the row matches all m stored snapshots."""
        # Hoisted: one device->host copy instead of one per stored snapshot.
        target = input_version.cpu()
        consistency = torch.ones(self.b, dtype=torch.bool)  # start True for every row
        for i in range(self.m):
            # Row-wise equality against snapshot i.
            consistency &= torch.all(self.history[:, :, i] == target, dim=1)
        return consistency

    def get_history(self):
        """Return the raw (b, d, m) history tensor (not a copy)."""
        return self.history
class ColumnHistoryChecker:
    """Ring buffer of the last ``m`` (b, d) snapshots, used to test column-wise stability.

    ``check_history(x)`` reports, for each of the ``d`` columns, whether that
    column of ``x`` is exactly equal to the same column in every one of the
    ``m`` stored snapshots.
    """

    def __init__(self, b, d, m):
        self.b = b  # number of rows
        self.d = d  # number of columns
        self.m = m  # number of snapshots kept
        # +inf sentinel (not zeros): it never equals a real entry, so no column
        # is reported consistent until all m slots hold real snapshots. This is
        # the same value ``torch.ones(...) * 1e99`` yields, since 1e99 overflows float32.
        self.history = torch.full((b, d, m), float('inf'))
        self.current_index = 0  # slot overwritten by the next update_history()

    def update_history(self, new_version):
        """Store ``new_version`` (copied to CPU) in the oldest slot of the ring buffer."""
        self.history[:, :, self.current_index] = new_version.cpu()
        self.current_index = (self.current_index + 1) % self.m

    def check_history(self, input_version):
        """Return a (d,) bool tensor: True where the column matches all m stored snapshots."""
        # Hoisted: one device->host copy instead of one per stored snapshot.
        target = input_version.cpu()
        consistency = torch.ones(self.d, dtype=torch.bool)  # start True for every column
        for i in range(self.m):
            # Column-wise equality against snapshot i.
            consistency &= torch.all(self.history[:, :, i] == target, dim=0)
        return consistency

    def get_history(self):
        """Return the raw (b, d, m) history tensor (not a copy)."""
        return self.history
| #---------------------------------------------------------------------------- | |
| def training_loop( | |
| run_dir = '.', # Output directory. | |
| training_set_kwargs = {}, # Options for training set. | |
| data_loader_kwargs = {}, # Options for torch.utils.data.DataLoader. | |
| G_kwargs = {}, # Options for generator network. | |
| D_kwargs = {}, # Options for discriminator network. | |
| G_opt_kwargs = {}, # Options for generator optimizer. | |
| D_opt_kwargs = {}, # Options for discriminator optimizer. | |
| augment_kwargs = None, # Options for augmentation pipeline. None = disable. | |
| loss_kwargs = {}, # Options for loss function. | |
| metrics = [], # Metrics to evaluate during training. | |
| random_seed = 0, # Global random seed. | |
| num_gpus = 1, # Number of GPUs participating in the training. | |
| rank = 0, # Rank of the current process in [0, num_gpus[. | |
| batch_size = 4, # Total batch size for one training iteration. Can be larger than batch_gpu * num_gpus. | |
| batch_gpu = 4, # Number of samples processed at a time by one GPU. | |
| ema_kimg = 10, # Half-life of the exponential moving average (EMA) of generator weights. | |
| ema_rampup = None, # EMA ramp-up coefficient. | |
| G_reg_interval = 4, # How often to perform regularization for G? None = disable lazy regularization. | |
| D_reg_interval = 16, # How often to perform regularization for D? None = disable lazy regularization. | |
| augment_p = 0, # Initial value of augmentation probability. | |
| ada_target = None, # ADA target value. None = fixed p. | |
| ada_interval = 4, # How often to perform ADA adjustment? | |
| ada_kimg = 500, # ADA adjustment speed, measured in how many kimg it takes for p to increase/decrease by one unit. | |
| total_kimg = 25000, # Total length of the training, measured in thousands of real images. | |
| kimg_per_tick = 4, # Progress snapshot interval. | |
| image_snapshot_ticks = 50, # How often to save image snapshots? None = disable. | |
| network_snapshot_ticks = 50, # How often to save network snapshots? None = disable. | |
| resume_pkl = None, # Network pickle to resume training from. | |
| cudnn_benchmark = True, # Enable torch.backends.cudnn.benchmark? | |
| allow_tf32 = False, # Enable torch.backends.cuda.matmul.allow_tf32 and torch.backends.cudnn.allow_tf32? | |
| abort_fn = None, # Callback function for determining whether to abort training. Must return consistent results across ranks. | |
| progress_fn = None, # Callback function for updating training progress. Called for all ranks. | |
| lambda_sparse = None, | |
| lambda_entropy = None, | |
| lambda_ortho = None, | |
| lambda_colvar = None, | |
| lambda_rowvar = None, | |
| lambda_equal = None, | |
| lambda_epsilon = None, | |
| lambda_path=None, | |
| g_iter=None, | |
| temperature=1, | |
| ): | |
| # Initialize. | |
| start_time = time.time() | |
| device = torch.device('cuda', rank) | |
| np.random.seed(random_seed * num_gpus + rank) | |
| torch.manual_seed(random_seed * num_gpus + rank) | |
| torch.backends.cudnn.benchmark = cudnn_benchmark # Improves training speed. | |
| torch.backends.cuda.matmul.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for matmul | |
| torch.backends.cudnn.allow_tf32 = allow_tf32 # Allow PyTorch to internally use tf32 for convolutions | |
| conv2d_gradfix.enabled = True # Improves training speed. | |
| grid_sample_gradfix.enabled = True # Avoids errors with the augmentation pipe. | |
| # Load training set. | |
| if rank == 0: | |
| print('Loading training set...') | |
| training_set = dnnlib.util.construct_class_by_name(**training_set_kwargs) # subclass of training.dataset.Dataset | |
| training_set_sampler = misc.InfiniteSampler(dataset=training_set, rank=rank, num_replicas=num_gpus, seed=random_seed) | |
| training_set_iterator = iter(torch.utils.data.DataLoader(dataset=training_set, sampler=training_set_sampler, batch_size=batch_size//num_gpus, **data_loader_kwargs)) | |
| if rank == 0: | |
| print() | |
| print('Num images: ', len(training_set)) | |
| print('Image shape:', training_set.image_shape) | |
| print('Label shape:', training_set.label_shape) | |
| print() | |
| # Construct networks. | |
| if rank == 0: | |
| print('Constructing networks...') | |
| common_kwargs = dict(c_dim=training_set.label_dim, img_resolution=training_set.resolution, img_channels=training_set.num_channels) | |
| G = dnnlib.util.construct_class_by_name(**G_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module | |
| D = dnnlib.util.construct_class_by_name(**D_kwargs, **common_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module | |
| G_ema = copy.deepcopy(G).eval() | |
| M_kwargs = dnnlib.EasyDict(class_name='training.networks.ConceptMaskNetwork', c_dim=training_set.label_dim, i_dim=G_kwargs.mapping_kwargs.i_dim) | |
| M = dnnlib.util.construct_class_by_name(**M_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module | |
| M_ema = copy.deepcopy(M).eval() | |
| # Resume from existing pickle. | |
| if (resume_pkl is not None) and (rank == 0): | |
| print(f'Resuming from "{resume_pkl}"') | |
| with dnnlib.util.open_url(resume_pkl) as f: | |
| resume_data = legacy.load_network_pkl(f) | |
| for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('M', M), ('M_ema', M_ema)]: | |
| misc.copy_params_and_buffers(resume_data[name], module, require_all=False) | |
| # Print network summary tables. | |
| if rank == 0: | |
| z = torch.empty([batch_gpu, G.z_dim], device=device) | |
| c = torch.empty([batch_gpu, G.c_dim], device=device) | |
| m = torch.empty([batch_gpu, G_kwargs.mapping_kwargs.i_dim], device=device) | |
| img = misc.print_module_summary(G, [z, m]) | |
| misc.print_module_summary(D, [img, c]) | |
| # Setup augmentation. | |
| if rank == 0: | |
| print('Setting up augmentation...') | |
| augment_pipe = None | |
| ada_stats = None | |
| if (augment_kwargs is not None) and (augment_p > 0 or ada_target is not None): | |
| augment_pipe = dnnlib.util.construct_class_by_name(**augment_kwargs).train().requires_grad_(False).to(device) # subclass of torch.nn.Module | |
| augment_pipe.p.copy_(torch.as_tensor(augment_p)) | |
| if ada_target is not None: | |
| ada_stats = training_stats.Collector(regex='Loss/signs/real') | |
| # Distribute across GPUs. | |
| if rank == 0: | |
| print(f'Distributing across {num_gpus} GPUs...') | |
| ddp_modules = dict() | |
| for name, module in [('G_mapping', G.mapping), ('G_synthesis', G.synthesis), ('D', D), (None, G_ema), ('augment_pipe', augment_pipe), | |
| ('M', M), (None, M_ema) | |
| ]: | |
| if (num_gpus > 1) and (module is not None) and len(list(module.parameters())) != 0: | |
| module.requires_grad_(True) | |
| module = torch.nn.parallel.DistributedDataParallel(module, device_ids=[device], broadcast_buffers=False) | |
| module.requires_grad_(False) | |
| if name is not None: | |
| ddp_modules[name] = module | |
| # Setup training phases. | |
| if rank == 0: | |
| print('Setting up training phases...') | |
| loss = dnnlib.util.construct_class_by_name(device=device, **ddp_modules, **loss_kwargs) # subclass of training.loss.Loss | |
| phases = [] | |
| for name, module, opt_kwargs, reg_interval in [('G', G, G_opt_kwargs, G_reg_interval), ('D', D, D_opt_kwargs, D_reg_interval)]: | |
| if reg_interval is None: | |
| opt = dnnlib.util.construct_class_by_name(params=module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer | |
| phases += [dnnlib.EasyDict(name=name+'both', module=module, opt=opt, interval=1)] | |
| else: # Lazy regularization. | |
| mb_ratio = reg_interval / (reg_interval + 1) | |
| opt_kwargs = dnnlib.EasyDict(opt_kwargs) | |
| opt_kwargs.lr = opt_kwargs.lr * mb_ratio | |
| opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas] | |
| opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer | |
| phases += [dnnlib.EasyDict(name=name+'main', module=module, opt=opt, interval=1)] | |
| if name == 'G' and g_iter>0: | |
| phases += ([dnnlib.EasyDict(name=name + 'main', module=module, opt=opt, interval=1)] * g_iter) | |
| phases += [dnnlib.EasyDict(name=name+'reg', module=module, opt=opt, interval=reg_interval)] | |
| for name, module, opt_kwargs, reg_interval in [('M', M, G_opt_kwargs, G_reg_interval)]: | |
| mb_ratio = reg_interval / (reg_interval + 1) | |
| opt_kwargs = dnnlib.EasyDict(opt_kwargs) | |
| opt_kwargs.lr = opt_kwargs.lr * mb_ratio | |
| opt_kwargs.betas = [beta ** mb_ratio for beta in opt_kwargs.betas] | |
| #M_opt = dnnlib.util.construct_class_by_name(module.parameters(), **opt_kwargs) # subclass of torch.optim.Optimizer | |
| #M_opt = torch.optim.SGD(module.parameters(), lr=0.01, momentum=0.9) | |
| print(opt_kwargs.betas, ' >>>>>>>> opt kwargs ssss') | |
| M_opt = torch.optim.AdamW(module.parameters(), lr=opt_kwargs.lr, betas=(0.9, 0.999), eps=opt_kwargs.eps, | |
| weight_decay=0.01, amsgrad=False) | |
| for phase in phases: | |
| phase.start_event = None | |
| phase.end_event = None | |
| if rank == 0: | |
| phase.start_event = torch.cuda.Event(enable_timing=True) | |
| phase.end_event = torch.cuda.Event(enable_timing=True) | |
| # Export sample images. | |
| grid_size = None | |
| grid_z = None | |
| grid_c = None | |
| if rank == 0: | |
| print('Exporting sample images...') | |
| grid_size, images, labels, num_domains = setup_snapshot_image_grid(training_set=training_set) | |
| save_image_grid(images, os.path.join(run_dir, 'reals.jpg'), drange=[0,255], grid_size=grid_size) | |
| if labels.shape[1] > 0: | |
| grid_z = [] | |
| for i in range(grid_size[1]//num_domains): | |
| random_z = (torch.randn(grid_size[0], G.z_dim, device=device)) | |
| for j in range(num_domains): | |
| grid_z.append(random_z) | |
| grid_z = torch.cat(grid_z, 0).split(batch_gpu) | |
| else: | |
| grid_z = torch.randn([labels.shape[0], G.z_dim], device=device).split(batch_gpu) | |
| grid_c = torch.from_numpy(labels).to(device) | |
| grid_c = grid_c.split(batch_gpu) | |
| images = torch.cat([G_ema(z=z, c=M_ema(c), noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]).numpy() | |
| save_image_grid(images, os.path.join(run_dir, 'fakes_init.jpg'), drange=[-1,1], grid_size=grid_size) | |
| # Initialize logs. | |
| if rank == 0: | |
| print('Initializing logs...') | |
| stats_collector = training_stats.Collector(regex='.*') | |
| stats_metrics = dict() | |
| stats_jsonl = None | |
| stats_tfevents = None | |
| if rank == 0: | |
| stats_jsonl = open(os.path.join(run_dir, 'stats.jsonl'), 'wt') | |
| try: | |
| import torch.utils.tensorboard as tensorboard | |
| stats_tfevents = tensorboard.SummaryWriter(run_dir) | |
| except ImportError as err: | |
| print('Skipping tfevents export:', err) | |
| # Train. | |
| if rank == 0: | |
| print(f'Training for {total_kimg} kimg...') | |
| print() | |
| cur_nimg = 0 | |
| cur_tick = 0 | |
| tick_start_nimg = cur_nimg | |
| tick_start_time = time.time() | |
| maintenance_time = tick_start_time - start_time | |
| init_temperature = 1.0 | |
| min_temperature = 0.5 | |
| batch_idx = 0 | |
| if progress_fn is not None: | |
| progress_fn(0, total_kimg) | |
| names = ['Red 0', 'Red 1', 'Green 0', 'Green 1', 'Green 2', 'Green 3', 'Green 4', 'Green 5', 'Green 6', 'Green 7', | |
| 'Green 8', 'Green 9', 'Red 2', 'Blue 0', 'Blue 1', 'Blue 2', 'Blue 3', 'Blue 4', 'Blue 5', 'Blue 6', 'Blue 7', 'Blue 8', 'Blue 9', | |
| 'Red 3', 'Red 4', 'Red 5', 'Red 6', 'Red 7', 'Red 8', 'Red 9' | |
| ] | |
| if G.mapping.c_dim == 30: | |
| names = [ | |
| 'Blue 0', 'Blue 1', 'Blue 2', 'Blue 3', 'Blue 4', 'Blue 5', 'Blue 6', 'Blue 7', 'Blue 8', 'Blue 9', | |
| 'Green 0', 'Green 1', 'Green 2', 'Green 3', 'Green 4', 'Green 5', 'Green 6', 'Green 7', 'Green 8', 'Green 9', | |
| 'Red 0', 'Red 1', 'Red 2','Red 3', 'Red 4', 'Red 5', 'Red 6', 'Red 7', 'Red 8', 'Red 9' | |
| ] | |
| elif G.mapping.c_dim == 8: | |
| names = [ | |
| 'Bald NoSmile Male', 'Bald Smile Male', 'Black NoSmile Female', 'Black NoSmile Male', 'Black Smile Female', 'Black Smile Male', | |
| 'Blond NoSmile Female', 'Blond Smile Female' | |
| ] | |
| #names = ['Green Apple', 'Green Banana', 'Green Pear', 'Red Apple', 'Red Pear', 'Red Strawberry', 'Yellow Banana', 'Yellow Pineapple', 'Yellow StarFruit'] | |
| #names = ['Green Apple', 'Green Banana', 'Green Pear', 'Red Apple', 'Red Pear', 'Red Strawberry', 'Yellow Banana', 'Yellow Pineapple', 'Yellow StarFruit'] | |
| #names = ['Yellow 1', 'Purple 1', 'Red 1', 'Yellow 2', 'White 1', 'White 2', 'Red 2', 'Purple 2'] | |
| version_history_checker = VectorHistoryChecker(G.mapping.c_dim, G.mapping.i_dim, 3) | |
| column_history_cheker = ColumnHistoryChecker(G.mapping.c_dim, G.mapping.i_dim, 3) | |
| binary_mask_checker = SparsestVector() | |
| use_best_binary = 10 | |
| while True: | |
| ready = False | |
| cur_kimg = cur_nimg / 1000.0 | |
| should_restart = (cur_tick % 40 ==0) | |
| if cur_tick<=5: | |
| cur_lambda_rowvar = lambda_rowvar | |
| cur_lambda_colvar = 0 | |
| cur_lambda_sparse = lambda_sparse | |
| cur_entropy_thr = 0.6 | |
| cur_lambda_equal = 0 | |
| cur_lambda_entropy = lambda_entropy | |
| else: | |
| cur_lambda_rowvar = 0 | |
| cur_lambda_colvar = lambda_colvar | |
| cur_lambda_sparse = lambda_sparse | |
| cur_entropy_thr = 0.9 | |
| cur_lambda_equal = lambda_equal | |
| cur_lambda_entropy = lambda_entropy | |
| cur_lambda_ortho = lambda_ortho | |
| cur_temperature = 1. | |
| # Fetch training data. | |
| with torch.autograd.profiler.record_function('data_fetch'): | |
| phase_real_img, phase_real_c = next(training_set_iterator) | |
| phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu) | |
| phase_real_c = phase_real_c.to(device).split(batch_gpu) | |
| all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device) | |
| all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)] | |
| all_gen_c = [training_set.get_label(np.random.randint(len(training_set))) for _ in range(len(phases) * batch_size)] | |
| """ | |
| all_gen_c = [] | |
| for ta in tmp_all_gen_c: | |
| all_gen_c.append(F.one_hot(torch.randint(0, 30, (1,)), num_classes=30).float().to(device).squeeze().cpu().numpy()) | |
| tmp_all_gen_c = torch.from_numpy(np.stack(tmp_all_gen_c)).to(device) | |
| print(all_gen_c.size(), ' >>>>>>>>>>>>>>>>> all genc ', tmp_all_gen_c.size(), ' >>>>>>>>>>>>>>>>> tmp all genc ') | |
| """ | |
| all_gen_c = torch.from_numpy(np.stack(all_gen_c)).pin_memory().to(device) | |
| all_gen_c = [phase_gen_c.split(batch_gpu) for phase_gen_c in all_gen_c.split(batch_size)] | |
| loss_dict = {} | |
| # Execute training phases. | |
| gmain_count = 0 | |
| for phase, phase_gen_z, phase_gen_c in zip(phases, all_gen_z, all_gen_c): | |
| if batch_idx % phase.interval != 0: | |
| continue | |
| if phase.name == 'Gmain': | |
| gmain_count += 1 | |
| only1G = ((cur_tick>use_best_binary) and (gmain_count>1) and (phase.name == 'Gmain')) | |
| if only1G: | |
| continue | |
| # Initialize gradient accumulation. | |
| if phase.start_event is not None: | |
| phase.start_event.record(torch.cuda.current_stream(device)) | |
| phase.opt.zero_grad(set_to_none=True) | |
| phase.module.requires_grad_(True) | |
| M_opt.zero_grad(set_to_none=True) | |
| if phase.name == 'Gmain': | |
| M.requires_grad_(True) | |
| # Accumulate gradients over multiple rounds. | |
| for round_idx, (real_img, real_c, gen_z, gen_c) in enumerate(zip(phase_real_img, phase_real_c, phase_gen_z, phase_gen_c)): | |
| sync = (round_idx == batch_size // (batch_gpu * num_gpus) - 1) | |
| gain = phase.interval | |
| tmp_loss_dict = loss.accumulate_gradients(phase=phase.name, real_img=real_img, real_c=real_c, gen_z=gen_z, gen_c=gen_c, sync=sync, gain=gain, | |
| lambda_sparse=cur_lambda_sparse, lambda_entropy=cur_lambda_entropy, lambda_ortho=cur_lambda_ortho, lambda_path=lambda_path, | |
| lambda_epsilon=lambda_epsilon, lambda_colvar=cur_lambda_colvar, lambda_rowvar=cur_lambda_rowvar, | |
| lambda_equal=cur_lambda_equal, temperature=cur_temperature, entropy_thr=cur_entropy_thr, | |
| ) | |
| loss_dict.update(tmp_loss_dict) | |
| # Update weights. | |
| phase.module.requires_grad_(False) | |
| M.requires_grad_(False) | |
| with torch.autograd.profiler.record_function(phase.name + '_opt'): | |
| for param in phase.module.parameters(): | |
| if param.grad is not None: | |
| misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) | |
| phase.opt.step() | |
| for param in M.parameters(): | |
| if param.grad is not None: | |
| misc.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad) | |
| M_opt.step() | |
| if phase.end_event is not None: | |
| phase.end_event.record(torch.cuda.current_stream(device)) | |
| # Update G_ema. | |
| with torch.autograd.profiler.record_function('Gema'): | |
| ema_nimg = ema_kimg * 1000 | |
| if ema_rampup is not None: | |
| ema_nimg = min(ema_nimg, cur_nimg * ema_rampup) | |
| ema_beta = 0.5 ** (batch_size / max(ema_nimg, 1e-8)) | |
| for p_ema, p in zip(G_ema.parameters(), G.parameters()): | |
| p_ema.copy_(p.lerp(p_ema, ema_beta)) | |
| for b_ema, b in zip(G_ema.buffers(), G.buffers()): | |
| b_ema.copy_(b) | |
| #ema_beta = 0.9 | |
| for p_ema, p in zip(M_ema.parameters(), M.parameters()): | |
| p_ema.copy_(p.lerp(p_ema, ema_beta)) | |
| for b_ema, b in zip(M_ema.buffers(), M.buffers()): | |
| b_ema.copy_(b) | |
| # Update state. | |
| cur_nimg += batch_size | |
| batch_idx += 1 | |
| # Execute ADA heuristic. | |
| if (ada_stats is not None) and (batch_idx % ada_interval == 0): | |
| ada_stats.update() | |
| adjust = np.sign(ada_stats['Loss/signs/real'] - ada_target) * (batch_size * ada_interval) / (ada_kimg * 1000) | |
| augment_pipe.p.copy_((augment_pipe.p + adjust).max(misc.constant(0, device=device))) | |
| # Perform maintenance tasks once per tick. | |
| done = (cur_nimg >= total_kimg * 1000) | |
| if (not done) and (cur_tick != 0) and (cur_nimg < tick_start_nimg + kimg_per_tick * 1000): | |
| continue | |
| # Print status line, accumulating the same information in stats_collector. | |
| tick_end_time = time.time() | |
| fields = [] | |
| fields += [f"tick {training_stats.report0('Progress/tick', cur_tick):<5d}"] | |
| fields += [f"kimg {training_stats.report0('Progress/kimg', cur_nimg / 1e3):<8.1f}"] | |
| fields += [f"time {dnnlib.util.format_time(training_stats.report0('Timing/total_sec', tick_end_time - start_time)):<12s}"] | |
| fields += [f"sec/tick {training_stats.report0('Timing/sec_per_tick', tick_end_time - tick_start_time):<7.1f}"] | |
| fields += [f"sec/kimg {training_stats.report0('Timing/sec_per_kimg', (tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg) * 1e3):<7.2f}"] | |
| #fields += [f"maintenance {training_stats.report0('Timing/maintenance_sec', maintenance_time):<6.1f}"] | |
| #fields += [f"cpumem {training_stats.report0('Resources/cpu_mem_gb', psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}"] | |
| #fields += [f"gpumem {training_stats.report0('Resources/peak_gpu_mem_gb', torch.cuda.max_memory_allocated(device) / 2**30):<6.2f}"] | |
| fields += [f"sparse {loss_dict['loss_sparse']:.3f}"] | |
| fields += [f"entropy {loss_dict['loss_entropy']:.3f}"] | |
| fields += [f"path {loss_dict['loss_path']:.3f}"] | |
| fields += [f"equal {loss_dict['loss_equal']:.3f}"] | |
| fields += [f"rowvar {loss_dict['loss_rowvar']:.3f}"] | |
| fields += [f"colvar {loss_dict['loss_colvar']:.3f}"] | |
| fields += [f"lambda_sparse {cur_lambda_sparse:.3f}"] | |
| fields += [f"lambda_entropy {cur_lambda_entropy:.3f}"] | |
| fields += [f"lambda_rowvar {cur_lambda_rowvar:.3f}"] | |
| fields += [f"lambda_colvar {cur_lambda_colvar:.3f}"] | |
| fields += [f"lambda_path {lambda_path:.3f}"] | |
| fields += [f"lambda_equal {lambda_equal:.3f}"] | |
| fields += [f"thr {cur_entropy_thr:.3f}"] | |
| torch.cuda.reset_peak_memory_stats() | |
| #fields += [f"augment {training_stats.report0('Progress/augment', float(augment_pipe.p.cpu()) if augment_pipe is not None else 0):.3f}"] | |
| training_stats.report0('Timing/total_hours', (tick_end_time - start_time) / (60 * 60)) | |
| training_stats.report0('Timing/total_days', (tick_end_time - start_time) / (24 * 60 * 60)) | |
| if rank == 0: | |
| print(' '.join(fields)) | |
| # Check for abort. | |
| if (not done) and (abort_fn is not None) and abort_fn(): | |
| done = True | |
| if rank == 0: | |
| print() | |
| print('Aborting...') | |
| # Save image snapshot. | |
| if (rank == 0) and (image_snapshot_ticks is not None) and (done or cur_tick % image_snapshot_ticks == 0): | |
| wss = torch.cat([G_ema.mapping(z,M_ema(c)) for z,c in zip(grid_z, grid_c)]) | |
| images = torch.cat([G_ema(z=z, c=M_ema(c), noise_mode='const').cpu() for z, c in zip(grid_z, grid_c)]) | |
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so its mean square along ``dim`` is approximately 1."""
    mean_square = x.square().mean(dim=dim, keepdim=True)
    inv_rms = (mean_square + eps).rsqrt()
    return x * inv_rms
| cs = [] | |
| for c in grid_c: | |
| cs.append(c.argmax(dim=1)) | |
| cs = torch.cat(cs, 0).view(G.mapping.c_dim, -1) | |
| tmp_imgs = images.reshape(G.mapping.c_dim, -1, images.shape[1], images.shape[2], images.shape[3]) | |
| images = images.numpy() | |
| wss = wss.reshape(G.mapping.c_dim, -1, wss.shape[1], wss.shape[2]) | |
| print(cs.size(), tmp_imgs.shape, wss.shape, ' >>>>>cs size tmp_imgs size <<<<<<<<') | |
| save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.jpg'), drange=[-1,1], grid_size=grid_size) | |
| try: | |
| print(G_ema.mapping.importance0, G_ema.mapping.importance1) | |
| except: | |
| pass | |
| all_masks = [] | |
| with torch.no_grad(): | |
| cin = torch.arange(G.mapping.c_dim, device=device) | |
| cin = F.one_hot(cin, num_classes=G.mapping.c_dim).float() | |
| all_logit = M(cin) | |
| all_soft_mask = ((all_logit)) | |
| all_hard_mask = (all_soft_mask > 0.5).float() | |
| for i in range(G.mapping.c_dim): | |
| print('%40s' % names[i], ' ', all_soft_mask[i].cpu().numpy()) | |
| for i in range(G.mapping.c_dim): | |
| print('%40s' % names[i], ' ', all_hard_mask[i].cpu().numpy().astype(np.uint8)) | |
| all_logit = M_ema(cin) | |
| all_soft_mask = ((all_logit)) | |
| all_hard_mask = (all_soft_mask > 0.5).float() | |
| for i in range(G.mapping.c_dim): | |
| print('%40s' % names[i], ' ', all_soft_mask[i].cpu().numpy()) | |
| for i in range(G.mapping.c_dim): | |
| print('%40s' % names[i], ' ', all_hard_mask[i].cpu().numpy().astype(np.uint8)) | |
| dscores = [] | |
| dhard_masks = all_hard_mask.clone() | |
| dsoft_masks = all_soft_mask.clone() | |
| for i in range(G.mapping.c_dim): | |
| cur_imgs = tmp_imgs[i].to(device) | |
| cur_c = F.one_hot(torch.tensor([i]*cur_imgs.size(0), device=device), num_classes=G.mapping.c_dim).float().to(device) | |
| d_out = D(cur_imgs, cur_c) | |
| d_out = F.softplus(d_out) | |
| print('%40s mean: %.2f min: %.2f max: %.2f' % (names[i], d_out.mean().item(), d_out.min().item(), d_out.max().item())) | |
| dscores.append(d_out.min().item()) | |
| #eval_mask = M(cin, eval=True) | |
| #for i in range(G.mapping.c_dim): | |
| # print('%10s' % names[i], ' ', eval_mask[i].cpu().numpy().astype(np.uint8)) | |
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so its mean square along ``dim`` is approximately 1."""
    mean_square = x.square().mean(dim=dim, keepdim=True)
    inv_rms = (mean_square + eps).rsqrt()
    return x * inv_rms
def get_onehot(y):
    """Return a tensor shaped like ``y`` with 1 at each last-dim argmax and 0 elsewhere."""
    orig_shape = y.size()
    winners = y.max(dim=-1)[1].view(-1, 1)
    flat = torch.zeros_like(y).view(-1, orig_shape[-1])
    flat.scatter_(1, winners, 1)
    return flat.view(*orig_shape)
def no_same_rows(x):
    """Return True if no two rows of ``x`` are ``torch.allclose`` to each other.

    Returns as soon as the first near-duplicate pair is found instead of
    continuing to compare every remaining pair.
    """
    n_rows = len(x)
    for i in range(n_rows):
        for j in range(i + 1, n_rows):
            if torch.allclose(x[i], x[j]):
                return False
    return True
def has_enough_concepts(x):
    """Return True if every row of ``x`` sums to more than 1.

    For a 0/1 mask this means every row has at least two active entries.
    Stops at the first failing row.
    """
    return all(torch.sum(row) > 1 for row in x)
| if no_same_rows(dhard_masks) and has_enough_concepts(dhard_masks): | |
| print('') | |
| print('>>>>>>>>>>>>> This version can be used <<<<<<<<<<<<<<') | |
| print('') | |
| ready = True | |
| binary_mask_checker.add(dhard_masks) | |
| try: | |
| best_mask = binary_mask_checker.check() | |
| for i in range(G.mapping.c_dim): | |
| print('%40s' % names[i], ' ', best_mask[i].cpu().numpy().astype(np.uint8), ' best') | |
| except: | |
| pass | |
| masks = all_soft_mask | |
| hard_masks = all_hard_mask | |
| for i in range(G.mapping.i_dim): | |
| cur_i_imgs = [] | |
| sorted_index = np.argsort(masks[:, i].cpu().numpy(), axis=0)[::-1] | |
| for j in sorted_index: | |
| if hard_masks[j, i] == 1: | |
| cur_i_imgs.append(tmp_imgs[j]) | |
| if len(cur_i_imgs) > 0: | |
| cur_i_imgs = torch.cat(cur_i_imgs, 0) | |
| save_image(cur_i_imgs, os.path.join(run_dir, f'concept_{cur_nimg // 1000:06d}_{i}.jpg'), | |
| nrow=grid_size[0], normalize=True, range=(-1, 1)) | |
| if True: | |
| for i in range(G.mapping.c_dim): | |
| if False: | |
| M.param_net.data[i] += -1e9*(dsoft_masks[i]<0.05) | |
| M_ema.param_net.data[i] += -1e9*(dsoft_masks[i]<0.05) | |
| M.use_param[i] = (dsoft_masks[i]<0.05).float() | |
| M_ema.use_param[i] = (dsoft_masks[i]<0.05).float() | |
| #print(dscores[i], names[i], ' >>>>>>. what fuck ', M.use_param.view(-1), M.param_net[i]) | |
| #topk = torch.topk(torch.tensor(dscores), k=5)[1] | |
| consistency = version_history_checker.check_history(dhard_masks) | |
| version_history_checker.update_history(dhard_masks) | |
| for i in range(G.mapping.c_dim): | |
| all_sum = torch.sum(dhard_masks, dim=1) | |
| target = torch.mode(all_sum)[0] | |
| cur_sum = all_sum[i] | |
| set_thr = 1.0 | |
| cond1 = (dscores[i]>=set_thr) | |
| crit = (cur_sum>1 and cur_sum<=target) | |
| #cond2 = (dscores[i]>=0.6 and cur_sum>1 and cur_sum<=target and (i in list(topk.cpu()))) | |
| cond3 = consistency[i] | |
| should_use=True | |
| for j in range(G.mapping.c_dim): | |
| if dscores[j]> dscores[i] and torch.sum(torch.abs(dhard_masks[i]-dhard_masks[j]))==0 and j!=i: | |
| should_use = False | |
| if (cond1) and should_use and crit: | |
| #M.param_net.data[i] = 1e9*dhard_masks[i] | |
| #M.param_net.data[i] += -1e9*(1-dhard_masks[i]) | |
| M.target_value[i] = dhard_masks[i] | |
| M.use_param[i] = torch.ones_like(M.use_param[i]) | |
| #M_ema.param_net.data[i] = 1e9*dhard_masks[i] | |
| #M_ema.param_net.data[i] += -1e9*(1-dhard_masks[i]) | |
| M_ema.target_value[i] = dhard_masks[i] | |
| M_ema.use_param[i] = torch.ones_like(M.use_param[i]) | |
| print('>>>>>> replace classss ', names[i], ' ', dscores[i], ' ', M.target_value[i], ' << consistency ', consistency[i]) | |
| column_consistency = column_history_cheker.check_history(dhard_masks) | |
| column_history_cheker.update_history(dhard_masks) | |
| for j in range(G.mapping.i_dim): | |
| cur_soft = dsoft_masks[:,j] | |
| cur_hard = dhard_masks[:,j] | |
| act = cur_soft[cur_hard==1] | |
| deact = cur_soft[cur_hard==0] | |
| cur_sum = torch.sum(cur_hard) | |
| if (act.mean()>0.9 and act.min()>0.6 and cur_sum>1 and cur_tick==5): | |
| #M.param_net.data[:,j] = cur_hard*19 | |
| #M.param_net.data[:,j] += -1e19*(1-cur_hard) | |
| M.use_param[:,j] = torch.ones_like(M.use_param[:,j]) | |
| M.target_value[:,j] = cur_hard | |
| #M_ema.param_net.data[:,j] = cur_hard | |
| #M_ema.param_net.data[:,j] += -1e19*(1-cur_hard) | |
| M_ema.target_value[:,j] = cur_hard | |
| M_ema.use_param[:,j] = torch.ones_like(M.use_param[:,j]) | |
| print('>>>>> replace columns ', j, ' ', M.target_value[:,j].view(-1), ' ', column_consistency[j]) | |
| if cur_tick == use_best_binary: | |
| best_mask = binary_mask_checker.check() | |
| if best_mask is not None: | |
| M.use_param = torch.ones_like(M.use_param) | |
| M.target_value = best_mask | |
| M_ema.use_param = torch.ones_like(M.use_param) | |
| M_ema.target_value = best_mask | |
| if (cur_tick % 5 ==0 and cur_tick>0) or cur_tick == use_best_binary: | |
| for param in M.parameters(): | |
| torch.distributed.broadcast(param.data, 0) | |
| torch.distributed.broadcast(M.use_param, 0) | |
| torch.distributed.broadcast(M_ema.use_param, 0) | |
| torch.distributed.broadcast(M.target_value, 0) | |
| torch.distributed.broadcast(M_ema.target_value, 0) | |
| for param in M_ema.parameters(): | |
| torch.distributed.broadcast(param.data, 0) | |
| torch.distributed.barrier() | |
| #print(M.use_param, ' >>>>>>> m M use_oaramssss bripdcatss ') | |
| # Save network snapshot. | |
| snapshot_pkl = None | |
| snapshot_data = None | |
| if (network_snapshot_ticks is not None) and (done or cur_tick % network_snapshot_ticks == 0) and cur_tick>0: | |
| snapshot_data = dict(training_set_kwargs=dict(training_set_kwargs)) | |
| for name, module in [('G', G), ('D', D), ('G_ema', G_ema), ('augment_pipe', augment_pipe), ('M', M), ('M_ema', M_ema)]: | |
| if module is not None: | |
| if num_gpus > 1: | |
| misc.check_ddp_consistency(module, ignore_regex=r'.*\.w_avg') | |
| module = copy.deepcopy(module).eval().requires_grad_(False).cpu() | |
| snapshot_data[name] = module | |
| del module # conserve memory | |
| snapshot_pkl = os.path.join(run_dir, f'network-snapshot-{cur_nimg//1000:06d}.pkl') | |
| if rank == 0: | |
| #pass | |
| with open(snapshot_pkl, 'wb') as f: | |
| pickle.dump(snapshot_data, f) | |
| # Evaluate metrics. | |
| if (snapshot_data is not None) and (len(metrics) > 0): | |
| if rank == 0: | |
| print('Evaluating metrics...') | |
| for metric in metrics: | |
| result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'], M=snapshot_data['M_ema'], | |
| dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device) | |
| if rank == 0: | |
| metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl) | |
| stats_metrics.update(result_dict.results) | |
| del snapshot_data # conserve memory | |
| # Collect statistics. | |
| for phase in phases: | |
| value = [] | |
| if (phase.start_event is not None) and (phase.end_event is not None): | |
| phase.end_event.synchronize() | |
| value = phase.start_event.elapsed_time(phase.end_event) | |
| training_stats.report0('Timing/' + phase.name, value) | |
| stats_collector.update() | |
| stats_dict = stats_collector.as_dict() | |
| # Update logs. | |
| timestamp = time.time() | |
| if stats_jsonl is not None: | |
| fields = dict(stats_dict, timestamp=timestamp) | |
| stats_jsonl.write(json.dumps(fields) + '\n') | |
| stats_jsonl.flush() | |
| if stats_tfevents is not None: | |
| global_step = int(cur_nimg / 1e3) | |
| walltime = timestamp - start_time | |
| for name, value in stats_dict.items(): | |
| stats_tfevents.add_scalar(name, value.mean, global_step=global_step, walltime=walltime) | |
| for name, value in stats_metrics.items(): | |
| stats_tfevents.add_scalar(f'Metrics/{name}', value, global_step=global_step, walltime=walltime) | |
| stats_tfevents.flush() | |
| if progress_fn is not None: | |
| progress_fn(cur_nimg // 1000, total_kimg) | |
| # Update state. | |
| if False and cur_tick%5==0: | |
| for paramgroup in M_opt.param_groups: | |
| paramgroup['lr'] = paramgroup['lr'] * 0.1 | |
| print('>>>>>>>LR decay <<<<<<< %.7f' % paramgroup['lr']) | |
| cur_tick += 1 | |
| tick_start_nimg = cur_nimg | |
| tick_start_time = time.time() | |
| maintenance_time = tick_start_time - tick_end_time | |
| if done: | |
| break | |
| # Done. | |
| if rank == 0: | |
| print() | |
| print('Exiting...') | |
| #---------------------------------------------------------------------------- | |