import json
import math
import os
import random
import subprocess
import sys
import time
from collections import OrderedDict, deque
from typing import Optional, Union

import numpy as np
import torch
from tap import Tap

import infinity.utils.dist as dist
class Args(Tap):
    local_out_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # directory for saving checkpoints
    data_path: str = ''             # dataset
    bed: str = ''                   # bed directory for copying checkpoints, in addition to local_out_path
    vae_ckpt: str = ''              # VAE ckpt
    exp_name: str = ''              # experiment name
    ds: str = 'oi'                  # only used in GPT training::load_viz_data & FID benchmark
    model: str = ''                 # for VAE training, 'b' or any other for GPT training
    short_cap_prob: float = 0.2     # prob for training with short captions
    project_name: str = 'Infinity'  # name of wandb project
    tf32: bool = True               # whether to use TensorFloat32
    auto_resume: bool = True        # whether to automatically resume from the last checkpoint found in args.bed
    rush_resume: str = ''           # pretrained infinity checkpoint
    nowd: int = 1                   # whether to disable weight decay on sparse params (like class token)
    enable_hybrid_shard: bool = False   # whether to use hybrid FSDP
    inner_shard_degree: int = 1     # inner degree for FSDP
    zero: int = 0                   # ds zero
    buck: str = 'chunk'             # =0 for using module-wise
    fsdp_orig: bool = True
    enable_checkpointing: str = None    # checkpointing strategy: full-block, self-attn
    pad_to_multiplier: int = 1      # >1 for padding the seq len to a multiple of this
    log_every_iter: bool = False
    checkpoint_type: str = 'torch'  # checkpoint_type: torch, onmistore
    seed: int = None                # 3407
    rand: bool = True               # actual seed = seed + (dist.get_rank()*512 if rand else 0)
    device: str = 'cpu'
    task_id: str = '2493513'
    trial_id: str = '7260554'
    robust_run_id: str = '00'
    ckpt_trials = []
    real_trial_id: str = '7260552'
    chunk_nodes: int = None
    is_master_node: bool = None
    # dir
    log_txt_path: str = ''
    t5_path: str = ''               # if not specified: automatically find from all bytenas
    online_t5: bool = True          # whether to use online t5 or load local features
    # GPT
    sdpa_mem: bool = True           # whether to use torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True)
    tfast: int = 0                  # compile GPT
    model_alias: str = 'b'          # [automatically set; don't specify this]
    rms: bool = False
    aln: float = 1e-3               # multiplier of ada_lin.w's initialization
    alng: float = -1                # multiplier of ada_lin.w[gamma channels]'s initialization, -1: the same as aln
    saln: bool = False              # whether to use a shared adaln layer
    haln: bool = True               # whether to use a specific adaln layer in the head layer
    nm0: bool = False               # norm before word proj linear
    tau: float = 1                  # tau of self attention in GPT
    cos: bool = True                # cosine attn as in swin v2
    swi: bool = False               # whether to use FFNSwiGLU, instead of vanilla FFN
    dp: float = -1
    drop: float = 0.0               # GPT's dropout (VAE's is --vd)
    hd: int = 0
    ca_gamma: float = -1            # >=0 for using layer-scale for cross attention
    diva: int = 1                   # rescale_attn_fc_weights
    hd0: float = 0.02               # head.w *= hd0
    dec: int = 1                    # dec depth
    cum: int = 3                    # cumulating fea map as GPT TF input, 0: not cum; 1: cum @ next hw, 2: cum @ final hw
    rwe: bool = False               # random word emb
    tp: float = 0.0                 # top-p
    tk: float = 0.0                 # top-k
    tini: float = 0.02              # init parameters
    cfg: float = 0.1                # >0: classifier-free guidance, drop cond with prob cfg
    rand_uncond = False             # whether to use a random, unlearnable uncond embedding
    ema: float = 0.9999             # VAE's ema ratio, not VAR's. 0.9977844 == 0.5 ** (32 / (10 * 1000)) from gans, 0.9999 from SD
    tema: float = 0                 # 0.9999 in DiffiT, DiT
    fp16: int = 0                   # 1: fp16, 2: bf16, >2: fp16's max scaling multiplier; TODO: remember to force quantize-related features to fp32! The residual had better be fp32 as well (per flash-attention); does nn.Conv2d have a use_float16 argument?
    fuse: bool = False              # whether to use fused mlp
    fused_norm: bool = False        # whether to use fused norm
    flash: bool = False             # whether to use a customized flash-attn kernel
    xen: bool = False               # whether to use xentropy
    use_flex_attn: bool = False     # whether to use flex_attn to speed up training
    stable: bool = False
    gblr: float = 1e-4
    dblr: float = None              # =gblr if None
    tblr: float = 6e-4
    glr: float = None
    dlr: float = None
    tlr: float = None               # vqgan: 4e-5
    gwd: float = 0.005
    dwd: float = 0.0005
    twd: float = 0.005              # vqgan: 0.01
    gwde: float = 0
    dwde: float = 0
    twde: float = 0
    ls: float = 0.0                 # label smooth
    lz: float = 0.0                 # z loss from PaLM = 1e-4 todo
    eq: int = 0                     # equalized loss
    ep: int = 100
    wp: float = 0
    wp0: float = 0.005
    wpe: float = 0.3                # 0.001, final cosine lr = wpe * peak lr
    sche: str = ''                  # cos, exp, lin
    log_freq: int = 50              # log frequency in the stdout
    gclip: float = 6.               # <=0 for no grad clip on VAE
    dclip: float = 6.               # <=0 for no grad clip on discriminator
    tclip: float = 2.               # <=0 for no grad clip on GPT; >100 for per-param clip (%= 100 automatically)
    cdec: bool = False              # decay the grad clip thresholds of GPT and GPT's word embed
    opt: str = 'adamw'              # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (4x lower than Adam's lr) and wd=0.8 (8x higher than Adam's); e.g. with a small batch_size, Lion performs worse than AdamW
    ada: str = ''                   # adam's beta0 and beta1 for VAE or GPT, '0_0.99' from style-swin and magvit, '0.5_0.9' from VQGAN
    dada: str = ''                  # adam's beta0 and beta1 for discriminator
    oeps: float = 0                 # adam's eps, pixart uses 1e-10
    afuse: bool = True              # fused adam
    # data
    pn: str = ''                    # pixel nums, choose from 0.06M, 0.25M, 1M
    scale_schedule: tuple = None    # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
    patch_size: int = None          # [automatically set; don't specify this] = 2 ** (len(args.scale_schedule) - 1)
    resos: tuple = None             # [automatically set; don't specify this]
    data_load_reso: int = None      # [automatically set; don't specify this]
    workers: int = 0                # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    lbs: int = 0                    # local batch size; if lbs != 0, bs will be ignored, and bs will be reset as round(args.lbs / args.ac) * dist.get_world_size()
    bs: int = 0                     # global batch size; if lbs != 0, bs will be ignored
    batch_size: int = 0             # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
    glb_batch_size: int = 0         # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
    ac: int = 1                     # gradient accumulation
    r_accu: float = 1.0             # [automatically set; don't specify this] = 1 / args.ac
    norm_eps: float = 1e-6          # norm eps for infinity
    tlen: int = 512                 # truncate text embedding to this length
    Ct5: int = 2048                 # feature dimension of text encoder
    use_bit_label: int = 1          # pred bitwise labels or index-wise labels
    bitloss_type: str = 'mean'      # mean or sum
    dynamic_resolution_across_gpus: int = 1     # allow dynamic resolution across gpus
    enable_dynamic_length_prompt: int = 0       # enable dynamic length prompt during training
    use_streaming_dataset: int = 0              # use streaming dataset
    iterable_data_buffersize: int = 90000       # streaming dataset buffer size
    save_model_iters_freq: int = 1000           # save model iter freq
    noise_apply_layers: int = -1                # Bitwise Self-Correction: apply noise to layers, -1 means do not apply noise
    noise_apply_strength: float = -1            # Bitwise Self-Correction: noise strength, -1 means do not apply noise
    noise_apply_requant: int = 1                # Bitwise Self-Correction: requantize after applying noise
    rope2d_each_sa_layer: int = 0               # apply rope2d to each self-attention layer
    rope2d_normalized_by_hw: int = 1            # apply normalized rope2d
    use_fsdp_model_ema: int = 0                 # use fsdp model ema
    add_lvl_embeding_only_first_block: int = 1  # apply lvl pe embedding only to the first block, or to each block
    reweight_loss_by_scale: int = 0             # reweight loss by scale
    always_training_scales: int = 100           # truncate training scales
    vae_type: int = 1                           # here 16/32/64 is a bsq vae with that many quant bits
    fake_vae_input: bool = False                # fake vae input for debugging
    model_init_device: str = 'cuda'             # model_init_device
    prefetch_factor: int = 2                    # prefetch_factor for dataset
    apply_spatial_patchify: int = 0             # apply spatial patchify or not
    debug_bsc: int = 0                          # save figs and set a breakpoint to debug bsc and check the input
    task_type: str = 't2i'                      # task type: t2i or t2v
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################
    ############################ Attention! The following arguments and configurations are set automatically, you can skip reading the following part ###############################

    # would be automatically set at runtime
    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_id: str = ''     # subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_msg: str = ''    # (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()  # [automatically set; don't specify this]
    cmd: str = ' '.join(a.replace('--exp_name=', '').replace('--exp_name ', '') for a in sys.argv[7:])  # [automatically set; don't specify this]
    tag: str = 'UK'             # [automatically set; don't specify this]
    acc_all: float = None       # [automatically set; don't specify this]
    acc_real: float = None      # [automatically set; don't specify this]
    acc_fake: float = None      # [automatically set; don't specify this]
    last_Lnll: float = None     # [automatically set; don't specify this]
    last_L1: float = None       # [automatically set; don't specify this]
    last_Ld: float = None       # [automatically set; don't specify this]
    last_wei_g: float = None    # [automatically set; don't specify this]
    grad_boom: str = None       # [automatically set; don't specify this]
    diff: float = None          # [automatically set; don't specify this]
    diffs: str = ''             # [automatically set; don't specify this]
    diffs_ema: str = None       # [automatically set; don't specify this]
    ca_performance: str = ''    # [automatically set; don't specify this]
    cur_phase: str = ''         # [automatically set; don't specify this]
    cur_it: str = ''            # [automatically set; don't specify this]
    cur_ep: str = ''            # [automatically set; don't specify this]
    remain_time: str = ''       # [automatically set; don't specify this]
    finish_time: str = ''       # [automatically set; don't specify this]
    iter_speed: float = None    # [automatically set; don't specify this]
    img_per_day: float = None   # [automatically set; don't specify this]
    max_nvidia_smi: float = 0   # [automatically set; don't specify this]
    max_memory_allocated: float = None  # [automatically set; don't specify this]
    max_memory_reserved: float = None   # [automatically set; don't specify this]
    num_alloc_retries: int = None       # [automatically set; don't specify this]
    MFU: float = None           # [automatically set; don't specify this]
    HFU: float = None           # [automatically set; don't specify this]

    # ==================================================================================================================
    # ======================== ignore these parts below since they are only for debug use ==============================
    # ==================================================================================================================
    dbg_modified: bool = False
    dbg_ks: bool = False
    dbg_ks_last = None
    dbg_ks_fp = None
    def dbg_ks_this_line(self, g_it: int):
        if self.dbg_ks:
            if self.dbg_ks_last is None:
                self.dbg_ks_last = deque(maxlen=6)
            from utils.misc import time_str
            self.dbg_ks_fp.seek(0)
            f_back = sys._getframe().f_back
            file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
            info = f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})'
            if g_it is not None:
                info += f' [g_it: {g_it}]'
            self.dbg_ks_last.append(info)
            self.dbg_ks_fp.write('\n'.join(self.dbg_ks_last) + '\n')
            self.dbg_ks_fp.flush()
    dbg: bool = 'KEVIN_LOCAL' in os.environ  # only used when debugging unused params in DDP
    ks: bool = False
    nodata: bool = False    # if True, will set nova=True as well
    nodata_tlen: int = 320
    nova: bool = False      # no val, no FID
    prof: int = 0           # profile
    prof_freq: int = 50     # profile
    tos_profiler_file_prefix: str = 'vgpt_default/'
    profall: int = 0

    def is_vae_visualization_only(self) -> bool:
        return self.v_seed > 0
    v_seed: int = 0         # v_seed != 0 means the visualization-only mode

    def is_gpt_visualization_only(self) -> bool:
        return self.g_seed > 0
    g_seed: int = 0         # g_seed != 0 means the visualization-only mode
    # ==================================================================================================================
    # ======================== ignore these parts above since they are only for debug use ==============================
    # ==================================================================================================================
    @property
    def gpt_training(self):  # used as a plain attribute below (e.g. `if args.gpt_training:`), so it must be a property
        return len(self.model) > 0
    def set_initial_seed(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        if self.seed is None:
            torch.backends.cudnn.deterministic = False
        else:
            seed = self.seed + (dist.get_rank()*512 if self.rand else 0)
            torch.backends.cudnn.deterministic = True
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
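    # Example: with seed=3407 and rand=True, rank 0 is seeded with 3407, rank 1 with 3919 (= 3407 + 512), rank 2 with 4431, and so on.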
    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed + dist.get_rank()*512)
        return g
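    # Assumed usage sketch: hand this generator to a DataLoader so each rank draws distinct but reproducible augmentation randomness,
    # e.g. torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, generator=args.get_different_generator_for_each_rank()).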
    def compile_model(self, m, fast):
        if fast == 0:
            return m
        return torch.compile(m, mode={
            1: 'reduce-overhead',
            2: 'max-autotune',
            3: 'default',
        }[fast]) if hasattr(torch, 'compile') else m
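    # Example: args.compile_model(gpt, args.tfast) returns the model unchanged when tfast == 0,
    # and maps tfast 1/2/3 to the torch.compile modes 'reduce-overhead' / 'max-autotune' / 'default'.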
    def dump_log(self):
        if not dist.is_local_master():
            return
        nd = {'is_master': dist.is_visualizer()}
        r_trial, trial = str(self.real_trial_id), str(self.trial_id)
        for k, v in {
            'name': self.exp_name, 'tag': self.tag, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch,
            'Lnll': self.last_Lnll, 'L1': self.last_L1,
            'Ld': self.last_Ld,
            'acc': self.acc_all, 'acc_r': self.acc_real, 'acc_f': self.acc_fake,
            'weiG': self.last_wei_g if (self.last_wei_g is None or math.isfinite(self.last_wei_g)) else -23333,
            'grad': self.grad_boom,
            'cur': self.cur_phase, 'cur_ep': self.cur_ep, 'cur_it': self.cur_it,
            'rema': self.remain_time, 'fini': self.finish_time, 'last_upd': time.strftime("%Y-%m-%d %H:%M", time.localtime()),
            'bsep': f'{self.glb_batch_size}/{self.ep}',
            'G_lrwd': f'{self.glr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.gwd:g}',
            'D_lrwd': f'{self.dlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.dwd:g}',
            'T_lrwd': f'{self.tlr:.1e}'.replace('.0', '').replace('-0', '-').replace('+0', '+') + f'/{self.twd:g}',
            'diff': self.diff, 'diffs': self.diffs, 'diffs_ema': self.diffs_ema if self.diffs_ema else None,
            'opt': self.opt,
            'is_master_node': self.is_master_node,
        }.items():
            if hasattr(v, 'item'):
                v = v.item()
            if v is None or (isinstance(v, str) and len(v) == 0):
                continue
            nd[k] = v
        if r_trial == trial:
            nd.pop('trial', None)
        with open(self.log_txt_path, 'w') as fp:
            json.dump(nd, fp, indent=2)

    def touch_log(self):  # the listener will kill this job if log_txt_path is not updated for 120s
        os.utime(self.log_txt_path)  # about 2e-6 sec

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        # self.as_dict() would contain methods, but we only need variables
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str):  # for compatibility with old versions
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            if k in {'is_large_model', 'gpt_training'}:
                continue
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e
    def set_tf32(self, tf32: bool):  # called as args.set_tf32(args.tf32), so it needs `self`
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
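    # Note: TF32 only takes effect on Ampere or newer GPUs; 'high' lets matmuls use TF32, 'highest' keeps full fp32 precision.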
    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                s.append(f'  {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'
def init_dist_and_get_args():
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    args.chunk_nodes = int(os.environ.get('CK', '') or '0')

    if len(args.extra_args) > 0 and args.is_master_node == 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')

    args.set_tf32(args.tf32)
    if args.dbg:
        torch.autograd.set_detect_anomaly(True)

    try: os.makedirs(args.bed, exist_ok=True)
    except: pass
    try: os.makedirs(args.local_out_path, exist_ok=True)
    except: pass

    day3 = 60*24*3
    dist.init_distributed_mode(local_out_path=args.local_out_path, fork=False, timeout_minutes=day3 if int(os.environ.get('LONG_DBG', '0') or '0') > 0 else 30)

    args.tlen = max(args.tlen, args.nodata_tlen)
    if args.zero and args.tema != 0:
        args.tema = 0
        print(f'======================================================================================')
        print(f'======================== WARNING: args.tema:=0, due to zero={args.zero} ========================')
        print(f'======================================================================================\n\n')

    if args.nodata:
        args.nova = True

    if not args.tos_profiler_file_prefix.endswith('/'): args.tos_profiler_file_prefix += '/'

    if args.alng < 0:
        args.alng = args.aln

    args.device = dist.get_device()
    args.r_accu = 1 / args.ac   # gradient accumulation
    args.data_load_reso = None
    args.rand |= args.seed is None
    args.sche = args.sche or ('lin0' if args.gpt_training else 'cos')
    if args.wp == 0:
        args.wp = args.ep * 1/100

    di = {
        'b': 'bilinear', 'c': 'bicubic', 'n': 'nearest', 'a': 'area', 'aa': 'area+area',
        'at': 'auto', 'auto': 'auto',
        'v': 'vae',
        'x': 'pix', 'xg': 'pix_glu', 'gx': 'pix_glu', 'g': 'pix_glu',
    }

    args.ada = args.ada or ('0.9_0.96' if args.gpt_training else '0.5_0.9')
    args.dada = args.dada or args.ada
    args.opt = args.opt.lower().strip()

    if args.lbs:
        bs_per_gpu = args.lbs / args.ac
    else:
        bs_per_gpu = args.bs / args.ac / dist.get_world_size()
    bs_per_gpu = round(bs_per_gpu)
    args.batch_size = bs_per_gpu
    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
    args.workers = min(args.workers, bs_per_gpu)

    args.dblr = args.dblr or args.gblr
    args.glr = args.ac * args.gblr * args.glb_batch_size / 256
    args.dlr = args.ac * args.dblr * args.glb_batch_size / 256
    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
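    # Linear-scaling example: with ac=1, tblr=6e-4 and a resolved glb_batch_size of 512, tlr = 1 * 6e-4 * 512 / 256 = 1.2e-3.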
    args.gwde = args.gwde or args.gwd
    args.dwde = args.dwde or args.dwd
    args.twde = args.twde or args.twd

    if args.dbg_modified:
        torch.autograd.set_detect_anomaly(True)
    args.dbg_ks &= dist.is_local_master()
    if args.dbg_ks:
        args.dbg_ks_fp = open(os.path.join(args.local_out_path, 'dbg_ks.txt'), 'w')

    # gpt args
    if args.gpt_training:
        assert args.vae_ckpt, 'VAE ckpt must be specified when training GPT'
        from infinity.models import alias_dict, alias_dict_inv
        if args.model in alias_dict:
            args.model = alias_dict[args.model]
            args.model_alias = alias_dict_inv[args.model]
        else:
            args.model_alias = args.model
            args.model = f'infinity_{args.model}'

    args.task_id = '123'
    args.trial_id = '123'
    args.robust_run_id = '0'
    args.log_txt_path = os.path.join(args.local_out_path, 'log.txt')

    ls = []  # must be a list (not the string '[]'), since AUTO_RESUME trial ids are appended and sorted below
    if 'AUTO_RESUME' in os.environ:
        ls.append(int(os.environ['AUTO_RESUME']))
    ls = sorted(ls, reverse=True)
    ls = [str(i) for i in ls]
    args.ckpt_trials = ls
    args.real_trial_id = args.trial_id if len(ls) == 0 else str(ls[-1])
    args.enable_checkpointing = None if args.enable_checkpointing in [False, 0, "0"] else args.enable_checkpointing
    args.enable_checkpointing = "full-block" if args.enable_checkpointing in [True, 1, "1"] else args.enable_checkpointing
    assert args.enable_checkpointing in [None, "full-block", "full-attn", "self-attn"], \
        f"only support no-checkpointing or full-block/full-attn checkpointing, but got {args.enable_checkpointing}."
    if len(args.exp_name) == 0:
        args.exp_name = os.path.basename(args.bed) or 'test_exp'

    if '-' in args.exp_name:
        args.tag, args.exp_name = args.exp_name.split('-', maxsplit=1)
    else:
        args.tag = 'UK'

    if dist.is_master():
        os.system(f'rm -rf {os.path.join(args.bed, "ready-node*")} {os.path.join(args.local_out_path, "ready-node*")}')

    if args.sdpa_mem:
        from torch.backends.cuda import enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
        enable_flash_sdp(True)
        enable_mem_efficient_sdp(True)
        enable_math_sdp(False)

    return args
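

# Minimal usage sketch (assumed caller, e.g. a train.py entry point):
#   args = init_dist_and_get_args()   # parse CLI flags, init torch.distributed, derive lr / batch-size fields
#   print(args)                       # __str__ dumps every resolved field
#   args.dump_log()                   # the local master writes {local_out_path}/log.txt as JSON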