import os
import copy
from pathlib import Path
import torch
import torch.distributed as dist
import numpy as np
import math
import socket

# source: https://github.com/LTH14/mar/blob/main/util/misc.py


def prep_torch():
    """Pin PyTorch and BLAS thread pools to the CPU cores available to this job."""
    cpu_cores = get_cpu_cores()
    torch.set_num_threads(cpu_cores)          # intra-op threads (e.g., matrix ops)
    torch.set_num_interop_threads(cpu_cores)  # inter-op parallelism
    os.environ["OMP_NUM_THREADS"] = str(cpu_cores)
    os.environ["MKL_NUM_THREADS"] = str(cpu_cores)
    os.environ["OPENBLAS_NUM_THREADS"] = str(cpu_cores)


def get_cpu_cores():
    """Return the number of CPU cores actually available to this process."""
    hostname = socket.gethostname()
    if "bridges2" in hostname:
        # On Bridges-2, trust the SLURM allocation.
        return int(os.environ["SLURM_JOB_CPUS_PER_NODE"])
    # Otherwise respect a cgroup CPU quota if one is set (e.g., inside a container).
    try:
        with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us", "r") as f:
            quota = int(f.read().strip())
        with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us", "r") as f:
            period = int(f.read().strip())
        if quota > 0:
            return max(1, quota // period)
    except Exception:
        pass
    # No quota (quota == -1) or cgroup files unavailable: fall back to the host count.
    return os.cpu_count()


def setup_distributed():
    dist.init_process_group(backend='nccl')
    # RANK, WORLD_SIZE and LOCAL_RANK are set by the launcher (e.g., torchrun).
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def save_model(args, epoch, model, optimizer, ema_params=None, epoch_name=None):
    if epoch_name is None:
        epoch_name = str(epoch)
    output_dir = Path(args.logdir)
    checkpoint_path = output_dir / ('checkpoint-%s.pth' % epoch_name)
    if ema_params is not None:
        # Build an EMA state dict by overwriting trainable parameters with their EMA copies.
        ema_state_dict = copy.deepcopy(model.state_dict())
        for i, (name, _value) in enumerate(model.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[i]
    else:
        ema_state_dict = None

    to_save = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'args': args,
        'model_ema': ema_state_dict,
    }
    save_on_master(to_save, checkpoint_path)


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup."""
    if epoch < args.warmup_epochs:
        # Linear warmup from 0 to args.lr over the first warmup_epochs epochs.
        lr = args.lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.n_epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr


def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list or 'diffloss' in name:
            no_decay.append(param)  # no weight decay on bias, norm and diffloss params
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
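

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): how add_weight_decay
# and adjust_learning_rate are typically wired into an optimizer and a
# per-epoch schedule. The `args` namespace below (lr, min_lr, warmup_epochs,
# n_epochs, logdir) is a hypothetical stand-in for the argparse config the
# real training script would pass in, and the toy nn.Linear model is only
# for illustration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse
    import torch.nn as nn

    args = argparse.Namespace(
        lr=1e-3, min_lr=1e-6, warmup_epochs=5, n_epochs=50, logdir="./out"
    )
    model = nn.Linear(8, 2)  # toy model standing in for the real network

    # Split parameters into decay / no-decay groups before building the optimizer.
    param_groups = add_weight_decay(model, weight_decay=0.05)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr)

    for epoch in range(args.n_epochs):
        # Warmup followed by half-cycle cosine decay, applied once per epoch here.
        lr = adjust_learning_rate(optimizer, epoch, args)
        # ... run one epoch of training, then optionally checkpoint:
        # save_model(args, epoch, model, optimizer)  # requires args.logdir to exist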