
Marrying Autoregressive Transformer and Diffusion with Multi-Reference Autoregression
Official PyTorch Implementation


This is a PyTorch/GPU implementation of the paper Marrying Autoregressive Transformer and Diffusion with Multi-Reference Autoregression:

@article{zhen2025marrying,
  title={Marrying Autoregressive Transformer and Diffusion with Multi-Reference Autoregression},
  author={Zhen, Dingcheng and Qiao, Qian and Yu, Tan and Wu, Kangxi and Zhang, Ziwei and Liu, Siyuan and Yin, Shunshun and Tao, Ming},
  journal={arXiv preprint arXiv:2506.09482},
  year={2025}
}


Preparation

Dataset

Download the ImageNet dataset and place it in IMAGENET_PATH.
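Before launching any of the scripts, a quick check of the dataset layout can save a failed run. This is a minimal sketch that assumes IMAGENET_PATH contains a train/ folder with one subdirectory per class (1,000 for ImageNet-1K), as in MAR's ImageFolder-based loader; that layout is not documented in this README, and the helper count_imagenet_classes is hypothetical, not part of this repo.

```shell
#!/usr/bin/env bash
# Hypothetical pre-flight check. The expected layout (train/<class_dir>/...) is an
# assumption borrowed from MAR's ImageFolder-based data loading.
count_imagenet_classes() {
  local root="$1"
  if [ ! -d "$root/train" ]; then
    echo "missing $root/train" >&2
    return 1
  fi
  # Count the immediate subdirectories of train/ (one per class).
  find "$root/train" -mindepth 1 -maxdepth 1 -type d | wc -l | tr -d ' '
}

# Usage (paths are placeholders):
# export IMAGENET_PATH=/data/imagenet
# count_imagenet_classes "$IMAGENET_PATH"   # expect 1000 for ImageNet-1K
```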

VAE Model

We adopt the VAE model from MAR; you can also download it here. The scripts below expect it at ckpt/vae/kl16.ckpt.

Installation

Download the code:

git clone https://github.com/TransDiff/TransDiff
cd TransDiff

A suitable conda environment named transdiff can be created and activated with:

conda env create -f environment.yaml
conda activate transdiff

For convenience, our pre-trained TransDiff models can be downloaded directly here as well:

Model                  FID-50K   Inception Score   #Params
TransDiff-B            2.47      244.2             290M
TransDiff-L            2.25      244.3             683M
TransDiff-H            1.69      282.0             1.3B
TransDiff-B MRAR       1.49      282.2             290M
TransDiff-L MRAR       1.61      293.4             683M
TransDiff-H MRAR       1.42      301.2             1.3B
TransDiff-L 512x512    2.51      286.6             683M

(Optional) Download Other Files

If you want to run evaluation on ImageNet 512x512, download the necessary file and put it in the fid_stats/ folder. If you want to train TransDiff with MRAR, download the MRAR index file and put it in the project root.

(Optional) Caching VAE Latents

Because our data augmentation consists only of center cropping and random flipping, the VAE latents can be pre-computed and saved to CACHED_PATH to save computation during TransDiff training:

torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_cache.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 \
--batch_size 128 \
--data_path ${IMAGENET_PATH} --cached_path ${CACHED_PATH}
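The flags above determine the size of each latent. This is a small arithmetic sketch, assuming the kl16 VAE downsamples by a factor of 16 (as its name and MAR's setup suggest) and that --patch_size p groups p x p latent positions into one token, as in MAR. Under those assumptions a 256x256 image becomes a 16x16 grid of 16-channel tokens and a 512x512 image a 32x32 grid. The helper latent_tokens is hypothetical.

```shell
# Hypothetical helper: token grid implied by --img_size, the VAE stride and --patch_size.
# A stride of 16 for kl16.ckpt is an assumption based on the checkpoint name and MAR.
latent_tokens() {
  local img_size="$1" patch_size="${2:-1}" vae_stride="${3:-16}" embed_dim="${4:-16}"
  # Side length of the token grid.
  local side=$(( img_size / vae_stride / patch_size ))
  # Each token concatenates patch_size^2 latent positions of embed_dim channels.
  local token_dim=$(( embed_dim * patch_size * patch_size ))
  echo "${side}x${side} tokens ($(( side * side )) total), ${token_dim} channels each"
}

latent_tokens 256   # 16x16 tokens (256 total), 16 channels each
latent_tokens 512   # 32x32 tokens (1024 total), 16 channels each
```

With --patch_size 1, as in every script in this README, the token count grows 4x when moving from 256x256 to 512x512.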

Usage

Demo

Run our interactive visualization demo.

Training

Script for the TransDiff-L 1StepAR setting (pretrain TransDiff-L with a DiffLoss width of 1024 channels, 800 epochs):

torchrun --nproc_per_node=8 --nnodes=8 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 \
--diffusion_batch_mul 4 \
--epochs 800 --warmup_epochs 100 --blr 1.0e-4 --batch_size 32 \
--output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
--data_path ${IMAGENET_PATH}
  • Training time is ~115h on 64 A100 GPUs with --batch_size 32.
  • Add --online_eval to evaluate FID during training (every 50 epochs).
  • (Optional) To train with cached VAE latents, add --use_cached --cached_path ${CACHED_PATH} to the arguments.
  • (Optional) If mixed-precision training with torch.cuda.amp.autocast() frequently stops with the error 'Loss is nan, stopping training', add --bf16 to the arguments.
  • (Optional) If necessary, you can use gradient accumulation by setting --gradient_accumulation_steps n.
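Under standard data-parallel training, the effective (global) batch size is the per-GPU --batch_size times the number of GPUs times --gradient_accumulation_steps. For the pretraining script above that is 32 x 64 x 1 = 2,048, and the MRAR and 512x512 fine-tuning scripts reach the same 2,048 with 16 x 64 x 2. The sketch below applies that arithmetic (an assumption about the training loop, not code read from this repo) and estimates how many accumulation steps would keep the effective batch at 2,048 on fewer GPUs. Both helper names are hypothetical.

```shell
# Effective batch = per-GPU batch x GPUs per node x nodes x accumulation steps
# (standard data-parallel arithmetic; assumed, not taken from this repo's code).
effective_batch() {
  local per_gpu="$1" gpus_per_node="$2" nodes="$3" accum="${4:-1}"
  echo $(( per_gpu * gpus_per_node * nodes * accum ))
}

# Accumulation steps needed to reach a target effective batch, rounded up.
accum_steps_for() {
  local target="$1" per_gpu="$2" gpus_per_node="$3" nodes="$4"
  local per_step=$(( per_gpu * gpus_per_node * nodes ))
  echo $(( (target + per_step - 1) / per_step ))
}

effective_batch 32 8 8        # 1StepAR pretraining: 2048
effective_batch 16 8 8 2      # MRAR / 512x512 fine-tuning: 2048
accum_steps_for 2048 32 8 1   # a single 8-GPU node: 8 steps
```

Rounding up can overshoot when 2,048 is not a multiple of the per-step batch: for example, 24 per GPU on 8 GPUs gives 11 steps and an effective batch of 2,112.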

Script for the TransDiff-L MRAR setting (fine-tune TransDiff-L with MRAR from the 1StepAR checkpoint, DiffLoss width of 1024 channels, 40 epochs):

torchrun --nproc_per_node=8 --nnodes=8 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 --mrar --bf16 \
--diffusion_batch_mul 2 \
--epochs 40 --warmup_epochs 10 --lr 5.0e-5 --batch_size 16 --gradient_accumulation_steps 2 \
--output_dir ${OUTPUT_DIR} --resume ${TRANSDIFF_L_1STEPAR_DIR} \
--data_path ${IMAGENET_PATH}

Script for the TransDiff-L 512x512 setting (fine-tune TransDiff-L at 512x512 from the 1StepAR checkpoint, DiffLoss width of 1024 channels, 150 epochs):

torchrun --nproc_per_node=8 --nnodes=8 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main.py \
--img_size 512 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 --ema_rate 0.999 --bf16 \
--diffusion_batch_mul 4 \
--epochs 150 --warmup_epochs 10 --lr 1.0e-4 --batch_size 16 --gradient_accumulation_steps 2 \
--only_train_diff \
--output_dir ${OUTPUT_DIR} --resume ${TRANSDIFF_L_1STEPAR_DIR} \
--data_path ${IMAGENET_PATH}
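The 512x512 script sets --ema_rate 0.999. Assuming this is the decay r of a standard exponential moving average of the model weights (as in MAR, on which this code is based; not confirmed in this README), each optimizer step updates the EMA copy as below. The averaging horizon of roughly 1/(1 - r) steps is a common rule of thumb.

```latex
\theta_{\text{ema}} \leftarrow r\,\theta_{\text{ema}} + (1 - r)\,\theta,
\qquad r = 0.999,
\qquad \frac{1}{1 - r} = 1000 \ \text{steps}
```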

Evaluation (ImageNet 256x256 and 512x512)

Evaluate TransDiff-L 1StepAR with classifier-free guidance:

torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 \
--output_dir ${OUTPUT_DIR} --resume ckpt/transdiff_l/ \
--evaluate --eval_bsz 256 --num_images 50000 \
--cfg 1.3 --scale_0 0.89 --scale_1 0.95
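For reference, this is the standard classifier-free guidance rule, the form MAR's DiffLoss sampler uses, where c is the class label and the empty-set symbol denotes the null (unconditional) label. Assuming --cfg plays the role of the guidance weight w (an assumption, not confirmed here), --cfg 1.3 pushes the prediction 30% beyond the conditional one along the conditional-minus-unconditional direction, while w = 1 recovers plain conditional sampling. What --scale_0 and --scale_1 control is not documented in this README, so they are not modeled.

```latex
\hat{\epsilon}_\theta(x_t, c) = \epsilon_\theta(x_t, \varnothing)
  + w \bigl( \epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing) \bigr),
\qquad w = 1.3
```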

Evaluate TransDiff-L MRAR with classifier-free guidance:

torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main.py \
--img_size 256 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 \
--output_dir ${OUTPUT_DIR} --resume ckpt/transdiff_l_mrar/ \
--evaluate --eval_bsz 256 --num_images 50000 \
--cfg 1.3 --scale_0 0.91 --scale_1 0.93

Evaluate TransDiff-L 512x512 with classifier-free guidance:

torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main.py \
--img_size 512 --vae_path ckpt/vae/kl16.ckpt --vae_embed_dim 16 --patch_size 1 \
--model transdiff_large --diffloss_w 1024 \
--output_dir ${OUTPUT_DIR} --resume ckpt/transdiff_l_512/ \
--evaluate --eval_bsz 64 --num_images 50000 \
--cfg 1.3 --scale_0 0.87 --scale_1 0.87
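Lowering --eval_bsz (for example, if sampling runs out of GPU memory) increases the number of sampling rounds needed to reach --num_images. This is a rough sketch, assuming --eval_bsz is a per-GPU batch size (not confirmed by this README); the helper eval_rounds is hypothetical.

```shell
# Minimum number of sampling rounds to produce num_images,
# assuming --eval_bsz is the per-GPU batch size.
eval_rounds() {
  local num_images="$1" eval_bsz="$2" world_size="$3"
  local per_round=$(( eval_bsz * world_size ))
  # Ceiling division: the last round may generate a few extra images.
  echo $(( (num_images + per_round - 1) / per_round ))
}

eval_rounds 50000 256 8   # 256x256 commands: 25 rounds of 2,048 images
eval_rounds 50000 64 8    # 512x512 command: 98 rounds of 512 images
```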

Evaluation settings (cfg, scale_0, scale_1) used for the benchmark results in the paper:

Model                  cfg    scale_0   scale_1
TransDiff-B            1.30   0.87      0.91
TransDiff-L            1.30   0.89      0.95
TransDiff-H            1.23   0.87      0.93
TransDiff-B MRAR       1.30   0.87      0.91
TransDiff-L MRAR       1.30   0.91      0.93
TransDiff-H MRAR       1.28   0.87      0.91
TransDiff-L 512x512    1.30   0.87      0.87
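When scripting evaluations for several models, a small lookup keeps the table values in one place. In this sketch only the --cfg, --scale_0 and --scale_1 values come from the table above; the short keys (B, L-MRAR, ...) and the function name guidance_flags are hypothetical, and the --model names for the B and H variants are not given in this README, so they are not included.

```shell
# Hypothetical lookup of the guidance flags from the benchmark-settings table.
# Keys are made up for this sketch; the values are copied from the table.
guidance_flags() {
  case "$1" in
    B|B-MRAR) echo "--cfg 1.30 --scale_0 0.87 --scale_1 0.91" ;;
    L)        echo "--cfg 1.30 --scale_0 0.89 --scale_1 0.95" ;;
    H)        echo "--cfg 1.23 --scale_0 0.87 --scale_1 0.93" ;;
    L-MRAR)   echo "--cfg 1.30 --scale_0 0.91 --scale_1 0.93" ;;
    H-MRAR)   echo "--cfg 1.28 --scale_0 0.87 --scale_1 0.91" ;;
    L-512)    echo "--cfg 1.30 --scale_0 0.87 --scale_1 0.87" ;;
    *) echo "unknown setting: $1" >&2; return 1 ;;
  esac
}

# Usage: append to an evaluation command (unquoted so the flags split into words).
# torchrun ... main.py ... --evaluate --eval_bsz 256 --num_images 50000 $(guidance_flags L-MRAR)
```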

Acknowledgements

A large portion of the code in this repo is based on MAR, diffusers, and timm.

Contact

If you have any questions, feel free to contact me by email. Enjoy!
