---
base_model: black-forest-labs/FLUX.1-dev
library_name: diffusers
base_model_relation: quantized
tags:
  - quantization
---

Visual comparison of FLUX.1-dev outputs using BF16 and torchao `int4_weight_only` quantization:

*[Figure: FLUX.1-dev output with BF16, Baroque / Futurist / Noir styles]*

*[Figure: FLUX.1-dev output with torchao `int4_weight_only`]*

## Usage with Diffusers

To use this quantized FLUX.1 [dev] checkpoint, you need to install the 🧨 diffusers and torchao libraries:

```bash
pip install -U torchao
```

For now, you need to install diffusers from this specific branch, which fixes an error when loading the model:

```bash
pip install git+https://github.com/huggingface/diffusers.git@torchao-int4-serialization
```

After installing the required libraries, you can run the following script:

```python
import torch
from diffusers import FluxPipeline

# `use_safetensors=False` is required because torchao quantized weights
# are saved in the pickle format, not safetensors.
pipe = FluxPipeline.from_pretrained(
    "diffusers/FLUX.1-dev-torchao-int4",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
    device_map="balanced",
)

prompt = "Baroque style, a lavish palace interior with ornate gilded ceilings, intricate tapestries, and dramatic lighting over a grand staircase."

pipe_kwargs = {
    "prompt": prompt,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs,
    generator=torch.manual_seed(0),
).images[0]

image.save("flux.png")
```
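If you want to confirm that the quantized weights were loaded correctly, a minimal sketch is to inspect one of the transformer's linear layers. The attribute path below (`transformer_blocks[0].attn.to_q`) assumes the current `FluxTransformer2DModel` layout; with torchao applied, the weight should be a quantized tensor subclass rather than a plain `torch.Tensor`:

```python
# Sanity check (illustrative): with torchao int4_weight_only applied,
# linear weights are torchao tensor subclasses, not plain torch.Tensor.
layer = pipe.transformer.transformer_blocks[0].attn.to_q
print(type(layer.weight))  # expected: a torchao quantized tensor subclass
```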

## How to generate this quantized checkpoint?

This checkpoint was created with the following script, starting from the `black-forest-labs/FLUX.1-dev` checkpoint:


```python
import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import TorchAoConfig as DiffusersTorchAoConfig
from transformers import TorchAoConfig as TransformersTorchAoConfig

# Quantize both the transformer (diffusers) and the T5 text encoder
# (transformers) with torchao int4 weight-only quantization.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersTorchAoConfig("int4_weight_only"),
        "text_encoder_2": TransformersTorchAoConfig("int4_weight_only"),
    }
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

# `safe_serialization` is set to `False` because a torchao quantized model
# can't be saved in the safetensors format.
pipe.save_pretrained("FLUX.1-dev-torchao-int4", safe_serialization=False)
```
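The `quant_mapping` takes a separate config per pipeline component, so you can adjust which parts get quantized. As an illustrative variation (not how this checkpoint was produced), the sketch below quantizes only the transformer and leaves `text_encoder_2` in bf16:

```python
# Illustrative variation: quantize only the transformer; components not
# listed in quant_mapping (e.g. text_encoder_2) stay in torch_dtype.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersTorchAoConfig("int4_weight_only"),
    }
)
```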