license: apache-2.0
library_name: diffusers
datasets:
  - ILSVRC/imagenet-1k

Gaussian Mixture Flow Matching Models (GMFlow)

Model used in the paper:

Gaussian Mixture Flow Matching Models
Hansheng Chen¹, Kai Zhang², Hao Tan², Zexiang Xu³, Fujun Luan², Leonidas Guibas¹, Gordon Wetzstein¹, Sai Bi²
¹Stanford University, ²Adobe Research, ³Hillbot

[arXiv] [GitHub]

Usage

First install the official code repository (see the GitHub link above); it provides the lib package imported in the example below.

We provide a Diffusers pipeline for easy inference. The following code demonstrates how to sample images from the pretrained GM-DiT model using the GM-ODE 2 solver and the GM-SDE 2 solver.

import torch
from huggingface_hub import snapshot_download
from lib.models.diffusions.schedulers import FlowEulerODEScheduler, GMFlowSDEScheduler
from lib.pipelines.gmdit_pipeline import GMDiTPipeline

# Currently the pipeline can only load local checkpoints, so we need to download the checkpoint first
ckpt = snapshot_download(repo_id='Lakonik/gmflow_imagenet_k8_ema')
pipe = GMDiTPipeline.from_pretrained(ckpt, variant='bf16', torch_dtype=torch.bfloat16)
pipe = pipe.to('cuda')

# Pick words that exist in ImageNet
words = ['jay', 'magpie']
class_ids = pipe.get_label_ids(words)

# Sample using GM-ODE 2 solver
pipe.scheduler = FlowEulerODEScheduler.from_config(pipe.scheduler.config)
generator = torch.manual_seed(42)
output = pipe(
    class_labels=class_ids,
    guidance_scale=0.45,
    num_inference_steps=32,
    num_inference_substeps=4,
    output_mode='mean',
    order=2,
    generator=generator)
for i, (word, image) in enumerate(zip(words, output.images)):
    image.save(f'{i:03d}_{word}_gmode2_step32.png')

# Sample using GM-SDE 2 solver (the first run may be slow due to CUDA compilation)
pipe.scheduler = GMFlowSDEScheduler.from_config(pipe.scheduler.config)
generator = torch.manual_seed(42)
output = pipe(
    class_labels=class_ids,
    guidance_scale=0.45,
    num_inference_steps=32,
    num_inference_substeps=1,
    output_mode='sample',
    order=2,
    generator=generator)
for i, (word, image) in enumerate(zip(words, output.images)):
    image.save(f'{i:03d}_{word}_gmsde2_step32.png')
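
The two sampling calls above differ only in the scheduler, num_inference_substeps, and output_mode. As a minimal sketch, those per-solver settings can be collected in one place so the shared arguments are written once. The names SOLVER_PRESETS, build_call_kwargs, and output_name are hypothetical helpers introduced here for illustration and are not part of the official repository; the values are copied from the example above.

```python
# Hypothetical helpers (not part of the official repository). They only
# gather the keyword arguments already used in the two sampling calls above.
SOLVER_PRESETS = {
    # GM-ODE 2 settings from the example above
    'gmode2': dict(num_inference_substeps=4, output_mode='mean', order=2),
    # GM-SDE 2 settings from the example above
    'gmsde2': dict(num_inference_substeps=1, output_mode='sample', order=2),
}


def build_call_kwargs(solver, class_labels, generator=None,
                      guidance_scale=0.45, num_inference_steps=32):
    """Return the keyword arguments for one pipeline call."""
    if solver not in SOLVER_PRESETS:
        raise KeyError(f'unknown solver preset: {solver!r}')
    kwargs = dict(
        class_labels=class_labels,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    # Solver-specific settings are layered on top of the shared arguments.
    kwargs.update(SOLVER_PRESETS[solver])
    return kwargs


def output_name(index, word, solver, num_inference_steps):
    """Build a file name following the naming pattern used above."""
    return f'{index:03d}_{word}_{solver}_step{num_inference_steps}.png'
```

With these helpers, a call would look like pipe(**build_call_kwargs('gmode2', class_ids, generator)) after assigning the matching scheduler to pipe.scheduler as shown above. The helpers do not select the scheduler themselves.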

Citation

@misc{gmflow,
      title={Gaussian Mixture Flow Matching Models}, 
      author={Hansheng Chen and Kai Zhang and Hao Tan and Zexiang Xu and Fujun Luan and Leonidas Guibas and Gordon Wetzstein and Sai Bi},
      year={2025},
      eprint={2504.05304},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2504.05304}, 
}