# TiM-T2I / config.yaml
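# Training configuration for the TiM XL text-to-image model (patch size 1),
# organized into three top-level sections: model (transport, loss, network,
# and frozen auxiliary encoders), data (dataset and dataloader), and
# training (optimization, schedule, and checkpointing).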
model:
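  # Transport / noise schedule. OT_FM presumably denotes an optimal-transport
  # flow-matching interpolant; P_mean and P_std are assumed to parameterize a
  # log-normal distribution over training times (EDM-style sampling), with
  # sigma_d the assumed data standard deviation.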
  transport:
    target: tim.schedulers.transports.OT_FM
    params:
      P_mean: 0.0
      P_std: 1.6
      sigma_d: 1.0
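  # Unified diffusion/consistency loss: the objective is assumed to mix a
  # diffusion term and a consistency term with the ratios below.
  # derivative_type "dde" with differential_epsilon presumably selects a
  # finite-difference estimate of the time derivative for the consistency
  # term, and the weight_time_* keys its time-dependent loss weighting.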
  unified_dcm_loss:
    diffusion_ratio: 0.5
    consistency_ratio: 0.1
    derivative_type: dde
    differential_epsilon: 0.005
    weight_time_type: sqrt
    weight_time_tangent: True
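  # Backbone: a DiT-style transformer at XL scale (depth 28, hidden size 1152,
  # 16 heads). input_size 16 with in_channels 32 matches the 16x16x32 latents
  # of the f32/c32 autoencoder under vae_dir, i.e. 512x512 pixel inputs.
  # new_condition "t-r" presumably conditions the network on the transition
  # interval between start and end times; z_dim 768 is assumed to be the
  # projection dimension for the REPA alignment features (see enc_dir below).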
  network:
    target: tim.models.t2i.tim_model.TiM
    params:
      input_size: 16
      patch_size: 1
      in_channels: 32
      depth: 28
      hidden_size: 1152
      cap_feat_dim: 1152
      num_heads: 16
      encoder_depth: 8
      qk_norm: True
      z_dim: 768
      new_condition: t-r
      use_new_embed: True
      distance_aware: True
      lora_hidden_size: 384
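  # Frozen auxiliary modules: the DC-AE autoencoder (f32/c32: 32x downsampling,
  # 32 latent channels), the Gemma-3-1B-it text encoder (last hidden state,
  # prompts dropped with probability proportion_empty_prompts, presumably for
  # classifier-free guidance), the RADIO v2.5-B encoder providing REPA
  # alignment targets weighted by proj_coeff, and an EMA copy of the weights.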
  # pretrained_vae:
  vae_dir: mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers
  # text encoder
  text_encoder_dir: google/gemma-3-1b-it
  proportion_empty_prompts: 0.1
  use_last_hidden_state: True
  max_seq_length: 256
  # repa encoder
  enc_dir: checkpoints/radio/radio-v2.5-b_half.pth.tar
  proj_coeff: 1.0
  # ema
  use_ema: True
  ema_decay: 0.9999
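# Dataset and dataloader. data_type image_ms presumably selects multi-scale
# (aspect-ratio bucketed) image training: packed_json holds the precomputed
# bucket sampler and jsonl_dir the per-sample metadata.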
data:
  data_type: image_ms
  dataset:
    root_dir: datasets/t2i_toy_dataset
    packed_json: datasets/t2i_toy_dataset/bucket_sampler.json
    jsonl_dir: datasets/t2i_toy_dataset/data_info.jsonl
  dataloader:
    num_workers: 4
    batch_size: 128  # Batch size (per device) for the training dataloader.
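# Optimization and checkpointing. With scale_lr: True the learning rate is
# assumed to be rescaled by (effective batch size / learning_rate_base_batch_size),
# e.g. 4 devices x 128 per-device batch x 1 accumulation step = 512, equal to
# the base batch size, so learning_rate stays at 1.0e-4.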
training:
  tracker: null
  max_train_steps: 500000
  checkpointing_steps: 1000
  checkpoints_total_limit: 2
  resume_from_checkpoint: latest
  learning_rate: 1.0e-4
  learning_rate_base_batch_size: 512
  scale_lr: True
  lr_scheduler: constant  # ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
  lr_warmup_steps: 0
  gradient_accumulation_steps: 1
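  # Optimizer, built from the target class with the given params (AdamW,
  # betas [0.9, 0.95]); the commented-out betas line appears to be an
  # OmegaConf-style resolver alternative for passing the betas as a tuple.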
  optimizer:
    target: torch.optim.AdamW
    params:
      # betas: ${tuple:0.9, 0.999}
      betas: [0.9, 0.95]
      weight_decay: 1.0e-2
      eps: 1.0e-6
  max_grad_norm: 1.0
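  # Remaining training flags: bf16 mixed precision with TF32 matmuls allowed,
  # validation every validation_steps steps, and checkpoint_list presumably
  # marking milestone steps whose checkpoints are kept permanently (outside
  # the rolling checkpoints_total_limit).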
  proportion_empty_prompts: 0.0
  mixed_precision: bf16  # ["no", "fp16", "bf16"]
  allow_tf32: True
  validation_steps: 500
  checkpoint_list: [100000, 200000, 300000, 400000]