import os import copy from hyimage.common.config import LazyCall as L from hyimage.models.hunyuan.configs.hunyuanimage_config import ( hunyuanimage_v2_1_cfg, hunyuanimage_v2_1_distilled_cfg, hunyuanimage_refiner_cfg, ) from hyimage.models.vae import load_vae from hyimage.common.config.base_config import ( DiTConfig, RepromptConfig, TextEncoderConfig, VAEConfig, ) from hyimage.models.text_encoder import TextEncoder HUNYUANIMAGE_V2_1_MODEL_ROOT = os.environ.get("HUNYUANIMAGE_V2_1_MODEL_ROOT", "./ckpts") # ============================================================================= # MODEL CONFIGURATIONS # ============================================================================= # ============================================================================= # V2.1 MODELS # ============================================================================= def HUNYUANIMAGE_V2_1_TEXT_ENCODER(**kwargs): return TextEncoderConfig( model=L(TextEncoder)( text_encoder_type="llm", max_length=1000, text_encoder_precision='fp16', tokenizer_type="llm", text_encoder_path=None, prompt_template=None, prompt_template_video=None, hidden_state_skip_layer=2, apply_final_norm=False, reproduce=False, logger=None, device=None, ), prompt_template="dit-llm-encode-v2", load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/text_encoder", text_len=1000, ) def HUNYUANIMAGE_V2_1_VAE_32x(**kwargs): return VAEConfig( model=L(load_vae)( vae_path=None, device="cuda", ), load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/vae/vae_2_1", cpu_offload=False, ) def HUNYUANIMAGE_V2_1_DIT(**kwargs): return DiTConfig( model=copy.deepcopy(hunyuanimage_v2_1_cfg), use_lora=False, use_cpu_offload=False, gradient_checkpointing=True, load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1.safetensors", use_compile=True, ) def HUNYUANIMAGE_V2_1_DIT_CFG_DISTILL(**kwargs): return DiTConfig( model=copy.deepcopy(hunyuanimage_v2_1_distilled_cfg), use_lora=False, use_cpu_offload=False, gradient_checkpointing=True, load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage2.1-distilled.safetensors", use_compile=True, ) # ============================================================================= # REFINER MODELS # ============================================================================= def HUNYUANIMAGE_REFINER_DIT(**kwargs): return DiTConfig( model=copy.deepcopy(hunyuanimage_refiner_cfg), use_lora=False, use_cpu_offload=False, gradient_checkpointing=True, load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/dit/hunyuanimage-refiner.safetensors", use_compile=True, ) def HUNYUANIMAGE_REFINER_VAE_32x(**kwargs): return VAEConfig( model=L(load_vae)( vae_path=None, device="cuda", ), load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/vae/vae_refiner", cpu_offload=False, ) def HUNYUANIMAGE_REFINER_TEXT_ENCODER(**kwargs): return TextEncoderConfig( model=L(TextEncoder)( text_encoder_type="llm", max_length=1000, text_encoder_precision='fp16', tokenizer_type="llm", text_encoder_path=None, prompt_template=None, prompt_template_video=None, hidden_state_skip_layer=2, apply_final_norm=False, reproduce=False, logger=None, device=None, ), prompt_template="dit-llm-encode", load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/text_encoder", text_len=256, ) # ============================================================================= # SPECIALIZED MODELS # ============================================================================= def HUNYUANIMAGE_REPROMPT(**kwargs): from hyimage.models.reprompt import RePrompt return RepromptConfig( model=L(RePrompt)( models_root_path=None, device_map="auto", ), load_from=f"{HUNYUANIMAGE_V2_1_MODEL_ROOT}/reprompt", )