from typing import Any, Optional, Union

from transformers.configuration_utils import PretrainedConfig
|
|
|
# Names exported by `from <module> import *`: only the configuration class.
__all__ = ["YakConfig"]
|
|
|
|
|
class YakConfig(PretrainedConfig):
    """Configuration class that stores the hyperparameters of a [`YakModel`].

    Every argument is stored as an instance attribute of the same name.
    `axes_dim` is copied into a fresh list so that instances never share (or
    alias) the same list object. Any extra keyword arguments are forwarded to
    [`PretrainedConfig`].

    NOTE(review): the model that consumes this configuration is not part of
    this file, so the meanings given below are inferred from the argument
    names and should be confirmed against the model implementation.

    Args:
        in_channels (`int`, *optional*, defaults to 16):
            Presumably the number of input channels.
        out_channels (`int`, *optional*, defaults to 16):
            Presumably the number of output channels.
        vec_in_dim (`int`, *optional*, defaults to 1536):
            Presumably the size of the pooled/vector conditioning input.
        context_in_dim (`int`, *optional*, defaults to 3072):
            Presumably the size of the context (text) conditioning input.
        hidden_size (`int`, *optional*, defaults to 1536):
            Presumably the width of the hidden representations.
        mlp_ratio (`int`, *optional*, defaults to 4):
            Presumably the MLP width expressed as a multiple of `hidden_size`.
        num_heads (`int`, *optional*, defaults to 12):
            Presumably the number of attention heads.
        depth (`int`, *optional*, defaults to 6):
            Presumably the number of (double-stream) blocks.
        depth_single_blocks (`int`, *optional*, defaults to 12):
            Presumably the number of single-stream blocks.
        axes_dim (`list`, *optional*, defaults to `[16, 56, 56]`):
            Presumably the per-axis positional-embedding dimensions. `None`
            selects the default; any other iterable is copied into a new list.
        theta (`int`, *optional*, defaults to 10000):
            Presumably the base of the rotary positional embedding.
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Presumably whether the QKV projections use a bias term.
        guidance_embed (`bool`, *optional*, defaults to `False`):
            Presumably whether a guidance embedding is used.
        checkpoint (`bool`, *optional*, defaults to `False`):
            Presumably whether gradient checkpointing is enabled.
        txt_type (`str`, *optional*, defaults to `"refiner"`):
            Presumably selects how the text input is processed.
        timestep_shift (`bool`, *optional*, defaults to `False`):
            Presumably whether timestep shifting is applied.
        base_shift (`float`, *optional*, defaults to 0.5):
            Presumably the lower bound of the timestep shift.
        max_shift (`float`, *optional*, defaults to 1.15):
            Presumably the upper bound of the timestep shift.
        vae_config (`PretrainedConfig` or `dict`, *optional*):
            Configuration of the associated VAE; stored as given.
        **kwargs:
            Forwarded unchanged to `PretrainedConfig.__init__`.
    """

    model_type: str = "yak"

    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        vec_in_dim: int = 1536,
        context_in_dim: int = 3072,
        hidden_size: int = 1536,
        mlp_ratio: int = 4,
        num_heads: int = 12,
        depth: int = 6,
        depth_single_blocks: int = 12,
        axes_dim: Optional[list] = None,
        theta: int = 10_000,
        qkv_bias: bool = True,
        guidance_embed: bool = False,
        checkpoint: bool = False,
        txt_type: str = "refiner",
        timestep_shift: bool = False,
        base_shift: float = 0.5,
        max_shift: float = 1.15,
        vae_config: Optional[Union[PretrainedConfig, dict]] = None,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.vec_in_dim = vec_in_dim
        self.context_in_dim = context_in_dim
        self.hidden_size = hidden_size
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.depth = depth
        self.depth_single_blocks = depth_single_blocks
        # A list literal in the signature would be created once and shared by
        # every instance; build a fresh list per instance instead, and copy
        # caller-supplied values so later mutation cannot leak across objects.
        self.axes_dim = [16, 56, 56] if axes_dim is None else list(axes_dim)
        self.theta = theta
        self.qkv_bias = qkv_bias
        self.guidance_embed = guidance_embed
        self.checkpoint = checkpoint
        self.txt_type = txt_type
        self.timestep_shift = timestep_shift
        self.base_shift = base_shift
        self.max_shift = max_shift

        self.vae_config = vae_config
|
|
|
|
|
|