#!/usr/bin/python3
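"""Build the frozen BSQ-VAE tokenizer and the Infinity transformer (plus an
optional EMA copy for FSDP training) from parsed training args and a VAE
state dict."""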
import gc
import os
import os.path as osp
import random
import sys
from copy import deepcopy
from typing import Tuple, Union

import colorama
import torch
import yaml

import infinity.utils.dist as dist
from infinity.models import Infinity
from infinity.models.ema import get_ema_model
from infinity.utils import arg_util, misc
from infinity.utils.misc import os_system

def build_vae_gpt(args: arg_util.Args, vae_st: dict, skip_gpt: bool, force_flash=False, device='cuda'):
    if args.vae_type in [8, 16, 18, 20, 24, 32, 64, 128]:
        from infinity.models.bsq_vae.vae import vae_model
        schedule_mode = "dynamic"
        codebook_dim = args.vae_type  # e.g. 18
        codebook_size = 2 ** codebook_dim
        if args.apply_spatial_patchify:
            # with spatial patchify, a 2x2 spatial block is folded into channels,
            # so the VAE itself only needs an 8x downsampling path (4 stages)
            patch_size = 8
            encoder_ch_mult = [1, 2, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4]
        else:
            patch_size = 16
            encoder_ch_mult = [1, 2, 4, 4, 4]
            decoder_ch_mult = [1, 2, 4, 4, 4]
        vae_local = vae_model(
            vae_st, schedule_mode, codebook_dim, codebook_size, patch_size=patch_size,
            encoder_ch_mult=encoder_ch_mult, decoder_ch_mult=decoder_ch_mult, test_mode=True,
        ).to(args.device)
        if args.fake_vae_input:
            # inputs are already tokenized: drop encoder/decoder to free GPU memory
            vae_local.encoder = None
            vae_local.decoder = None
            torch.cuda.empty_cache()
    else:
        raise ValueError(f"vae_type {args.vae_type} not supported")
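    # NOTE (added context): the quantizer is bitwise, so each latent token is a
    # codebook_dim-bit code and the implicit vocabulary is 2**codebook_dim
    # (e.g. vae_type=32 -> 2**32, about 4.3e9 codes). This is why the transformer
    # below takes use_bit_label rather than a softmax over the full codebook.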
    if force_flash:
        args.flash = True

    gpt_kw = dict(
        pretrained=False, global_pool='',
        text_channels=args.Ct5, text_maxlen=args.tlen,
        norm_eps=args.norm_eps, rms_norm=args.rms,
        shared_aln=args.saln, head_aln=args.haln,
        cond_drop_rate=args.cfg, rand_uncond=args.rand_uncond, drop_rate=args.drop,
        cross_attn_layer_scale=args.ca_gamma, nm0=args.nm0, tau=args.tau, cos_attn=args.cos, swiglu=args.swi,
        raw_scale_schedule=args.scale_schedule,
        head_depth=args.dec,
        top_p=args.tp, top_k=args.tk,
        customized_flash_attn=args.flash, fused_mlp=args.fuse, fused_norm=args.fused_norm,
        checkpointing=args.enable_checkpointing,
        pad_to_multiplier=args.pad_to_multiplier,
        use_flex_attn=args.use_flex_attn,
        batch_size=args.batch_size,
        add_lvl_embeding_only_first_block=args.add_lvl_embeding_only_first_block,
        use_bit_label=args.use_bit_label,
        rope2d_each_sa_layer=args.rope2d_each_sa_layer,
        rope2d_normalized_by_hw=args.rope2d_normalized_by_hw,
        pn=args.pn,
        train_h_div_w_list=args.train_h_div_w_list,
        always_training_scales=args.always_training_scales,
        apply_spatial_patchify=args.apply_spatial_patchify,
    )
    if args.dp >= 0:
        gpt_kw['drop_path_rate'] = args.dp
    if args.hd > 0:
        gpt_kw['num_heads'] = args.hd
    print(f'[create gpt_wo_ddp] constructor kw={gpt_kw}\n')
    gpt_kw['vae_local'] = vae_local

    model_str = args.model.replace('vgpt', 'infinity')  # legacy model names used 'vgpt'
    print(f"{model_str=}")
    # an optional trailing 'c<N>' suffix sets how many transformer blocks are
    # grouped per chunk, e.g. 'infinity_2bc4' -> model 'infinity_2b', block_chunks=4
    if model_str.rsplit('c', maxsplit=1)[-1].isdecimal():
        model_str, block_chunks = model_str.rsplit('c', maxsplit=1)
        block_chunks = int(block_chunks)
    else:
        block_chunks = 1
    gpt_kw['block_chunks'] = block_chunks

    from timm.models import create_model
    gpt_wo_ddp: Infinity = create_model(model_str, **gpt_kw)
    if args.use_fsdp_model_ema:
        gpt_wo_ddp_ema = get_ema_model(gpt_wo_ddp)
    else:
        gpt_wo_ddp_ema = None
    gpt_wo_ddp = gpt_wo_ddp.to(device)

    # the VAE stays frozen; every transformer parameter must remain trainable
    assert all(not p.requires_grad for p in vae_local.parameters())
    assert all(p.requires_grad for n, p in gpt_wo_ddp.named_parameters())
    return vae_local, gpt_wo_ddp, gpt_wo_ddp_ema
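
# Usage sketch (illustrative; the entry point and checkpoint attribute below are
# assumptions about the surrounding repo, not guaranteed by this file):
#
#   args = arg_util.init_dist_and_get_args()            # parse CLI args, init dist
#   vae_st = torch.load(args.vae_ckpt, map_location='cpu')
#   vae, gpt, gpt_ema = build_vae_gpt(args, vae_st, skip_gpt=False, device=args.device)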
if __name__ == '__main__':
    # NOTE: `ld` is not defined in this module; it appears to be a checkpoint
    # loading/inspection helper from elsewhere in the repo, so running this file
    # directly as-is raises NameError.
    ld(sys.argv[1])
