from typing import Any
from typing import Union, Optional

from transformers.configuration_utils import PretrainedConfig

__all__ = ["YakConfig"]


class YakConfig(PretrainedConfig):
    """Configuration class that stores the hyper-parameters of a [`YakModel`].

    Every constructor argument is stored as an attribute of the same name.
    No validation or type conversion is performed here. Any extra keyword
    arguments are forwarded to [`PretrainedConfig`].

    NOTE(review): the per-argument descriptions below are inferred from the
    argument names. Confirm them against `YakModel`, which consumes these
    values.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            Number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            Number of output channels.
        vec_in_dim (`int`, *optional*, defaults to 1536):
            Input dimension of the vector conditioning input.
        context_in_dim (`int`, *optional*, defaults to 3072):
            Input dimension of the context (e.g. text) features.
        hidden_size (`int`, *optional*, defaults to 1536):
            Hidden dimension of the model.
        mlp_ratio (`int`, *optional*, defaults to 4):
            Presumably the ratio of the MLP hidden size to `hidden_size`.
        num_heads (`int`, *optional*, defaults to 12):
            Number of attention heads.
        depth (`int`, *optional*, defaults to 6):
            Number of (presumably double-stream) transformer blocks.
        depth_single_blocks (`int`, *optional*, defaults to 12):
            Number of single-stream transformer blocks.
        axes_dim (`list`, *optional*, defaults to `[16, 56, 56]`):
            Per-axis dimensions, presumably for the positional embedding.
            When omitted, a new list is created for each instance. Instances
            therefore never share (and cannot mutate) a common default list.
        theta (`int`, *optional*, defaults to 10000):
            Presumably the base frequency of the positional embedding.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether the query/key/value projections use a bias.
        guidance_embed (`bool`, *optional*, defaults to `False`):
            Whether a guidance embedding is used.
        checkpoint (`bool`, *optional*, defaults to `False`):
            Presumably toggles gradient (activation) checkpointing.
        txt_type (`str`, *optional*, defaults to `"refiner"`):
            Selects how text inputs are processed. The accepted values are
            defined by `YakModel`, not by this class.
        timestep_shift (`bool`, *optional*, defaults to `False`):
            Presumably enables shifting of the timestep schedule.
        base_shift (`float`, *optional*, defaults to 0.5):
            Presumably the base shift used when `timestep_shift` is enabled.
        max_shift (`float`, *optional*, defaults to 1.15):
            Presumably the maximum shift used when `timestep_shift` is enabled.
        vae_config (`PretrainedConfig` or `dict`, *optional*):
            Optional VAE configuration. It is stored as given; a `dict` is
            not converted into a config object here.
        **kwargs:
            Forwarded to [`PretrainedConfig.__init__`].
    """

    # Model-type identifier for this configuration class.
    model_type: str = "yak"

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        vec_in_dim: int = 1536,
        context_in_dim: int = 3072,
        hidden_size: int = 1536,
        mlp_ratio: int = 4,
        num_heads: int = 12,
        depth: int = 6,
        depth_single_blocks: int = 12,
        axes_dim: Optional[list] = None,
        theta: int = 10_000,
        qkv_bias: bool = True,
        guidance_embed: bool = False,
        checkpoint: bool = False,
        txt_type: str = "refiner",
        timestep_shift: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        vae_config: Optional[Union[PretrainedConfig, dict]] = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.vec_in_dim = vec_in_dim
        self.context_in_dim = context_in_dim
        self.hidden_size = hidden_size
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.depth = depth
        self.depth_single_blocks = depth_single_blocks
        # Build the default inside the call: a list literal in the signature
        # would be evaluated once and shared by every instance.
        self.axes_dim = [16, 56, 56] if axes_dim is None else axes_dim
        self.theta = theta
        self.qkv_bias = qkv_bias
        self.guidance_embed = guidance_embed
        self.checkpoint = checkpoint
        self.txt_type = txt_type
        self.timestep_shift = timestep_shift
        self.base_shift = base_shift
        self.max_shift = max_shift
        self.vae_config = vae_config