---
base_model: black-forest-labs/FLUX.1-dev
library_name: diffusers
base_model_relation: quantized
tags:
- quantization
---
|
|
|
# Visual comparison of Flux-dev model outputs using BF16 and torchao int4_weight_only quantization
|
|
|
<table>
  <tr>
    <td style="text-align: center;">
      BF16<br>
      <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_bf16_combined.png" alt="Flux-dev output with BF16: Baroque, Futurist, Noir styles">
    </td>
    <td style="text-align: center;">
      torchao int4_weight_only<br>
      <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_torchao_4bit_combined.png" alt="Flux-dev output with torchao int4_weight_only: Baroque, Futurist, Noir styles">
    </td>
  </tr>
</table>
|
|
|
# Usage with Diffusers
|
|
|
To use this quantized FLUX.1 [dev] checkpoint, you need both the 🧨 diffusers and torchao libraries. First, install torchao:
|
|
|
```shell
pip install -U torchao
```
|
|
|
For now, you need to install diffusers from this specific branch, which fixes an error when loading the model:
|
|
|
```shell
pip install git+https://github.com/huggingface/diffusers.git@torchao-int4-serialization
```
|
|
|
After installing the required libraries, you can run the following script:
|
|
|
```python
import torch
from diffusers import FluxPipeline

# The torchao checkpoint is stored in the PyTorch (.bin) format,
# so `use_safetensors=False` is required when loading it.
pipe = FluxPipeline.from_pretrained(
    "diffusers/FLUX.1-dev-torchao-int4",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
    device_map="balanced"
)

prompt = "Baroque style, a lavish palace interior with ornate gilded ceilings, intricate tapestries, and dramatic lighting over a grand staircase."

pipe_kwargs = {
    "prompt": prompt,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs, generator=torch.manual_seed(0),
).images[0]

image.save("flux.png")
```
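Since only the transformer and the second text encoder are quantized, you can optionally sanity-check the memory savings after loading. A minimal sketch, assuming your installed diffusers and transformers versions expose `get_memory_footprint()` on these models:

```python
# Hypothetical sanity check (assumes `get_memory_footprint()` is available,
# as in recent diffusers/transformers releases): report in-memory sizes.
def footprint_gb(module):
    return module.get_memory_footprint() / 1024**3

print(f"transformer: {footprint_gb(pipe.transformer):.2f} GB")        # int4-quantized
print(f"text_encoder_2: {footprint_gb(pipe.text_encoder_2):.2f} GB")  # int4-quantized
```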
|
|
|
# How to generate this quantized checkpoint?
|
|
|
This checkpoint was created from the "black-forest-labs/FLUX.1-dev" checkpoint with the following script:
|
|
|
```python
import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import TorchAoConfig as DiffusersTorchAoConfig
from transformers import TorchAoConfig as TransformersTorchAoConfig

# Apply int4 weight-only quantization to the transformer (handled by diffusers)
# and to the T5 text encoder (handled by transformers).
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersTorchAoConfig("int4_weight_only"),
        "text_encoder_2": TransformersTorchAoConfig("int4_weight_only"),
    }
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced"
)

# `safe_serialization` is set to `False` because torchao-quantized models
# cannot be saved in the safetensors format.
pipe.save_pretrained("FLUX.1-dev-torchao-int4", safe_serialization=False)
```
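If you want to host the resulting checkpoint under your own namespace, the pipeline can also be uploaded directly. A minimal sketch, assuming you are logged in to the Hub and `your-username/FLUX.1-dev-torchao-int4` is a placeholder repo id you can write to:

```python
# Hypothetical upload of the quantized pipeline; `safe_serialization=False`
# for the same reason as in `save_pretrained` above.
pipe.push_to_hub(
    "your-username/FLUX.1-dev-torchao-int4",  # placeholder repo id
    safe_serialization=False,
)
```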