---
base_model: black-forest-labs/FLUX.1-dev
library_name: diffusers
base_model_relation: quantized
tags:
- quantization
---

# Visual comparison of Flux-dev model outputs using BF16 and torchao int4_weight_only quantization

<table>
  <tr>
    <td style="text-align: center;">
      BF16<br>
      <medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_bf16_combined.png" alt="Flux-dev output with BF16: Baroque, Futurist, Noir styles"></medium-zoom>
    </td>
    <td style="text-align: center;">
      torchao int4_weight_only<br>
      <medium-zoom background="rgba(0,0,0,.7)"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/quantization-backends-diffusers/combined_flux-dev_torchao_4bit_combined.png" alt="torchao int4_weight_only Output"></medium-zoom>
    </td>
  </tr>
</table>

# Usage with Diffusers

To use this quantized FLUX.1 [dev] checkpoint, you need to install the 🧨 diffusers and torchao libraries:

```shell
pip install -U torchao
```

For now, you also need this specific diffusers branch, which fixes an error when loading the model:

```shell
pip install git+https://github.com/huggingface/diffusers.git@torchao-int4-serialization
```

After installing the required libraries, you can run the following script:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "diffusers/FLUX.1-dev-torchao-int4",
    torch_dtype=torch.bfloat16,
    use_safetensors=False,
    device_map="balanced"
)

prompt = "Baroque style, a lavish palace interior with ornate gilded ceilings, intricate tapestries, and dramatic lighting over a grand staircase."

pipe_kwargs = {
    "prompt": prompt,
    "height": 1024,
    "width": 1024,
    "guidance_scale": 3.5,
    "num_inference_steps": 50,
    "max_sequence_length": 512,
}

image = pipe(
    **pipe_kwargs, generator=torch.manual_seed(0),
).images[0]

image.save("flux.png")
```
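To sanity-check the memory savings from int4 weights, you can sum the in-memory size of a module's parameters and buffers. This is a sketch: `param_bytes` is a hypothetical helper, not part of diffusers; on the pipeline above you would compare `param_bytes(pipe.transformer)` against the BF16 checkpoint.

```python
import torch
import torch.nn as nn

def param_bytes(module: nn.Module) -> int:
    """Total bytes occupied by a module's parameters and buffers."""
    return sum(t.numel() * t.element_size()
               for t in list(module.parameters()) + list(module.buffers()))

# Example on a toy module; for the pipeline above, call
# param_bytes(pipe.transformer) and divide by 1e9 for GB.
layer = nn.Linear(4, 2)  # 4*2 weights + 2 biases = 10 float32 params
print(param_bytes(layer))
```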

# How to generate this quantized checkpoint?

This checkpoint was created with the following script, starting from the "black-forest-labs/FLUX.1-dev" checkpoint:

```python
import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers import TorchAoConfig as DiffusersTorchAoConfig
from transformers import TorchAoConfig as TransformersTorchAoConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={
        "transformer": DiffusersTorchAoConfig("int4_weight_only"),
        "text_encoder_2": TransformersTorchAoConfig("int4_weight_only"),
    }
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
    device_map="balanced"
)

# `safe_serialization` is set to `False` because torchao-quantized models can't be saved in safetensors format
pipe.save_pretrained("FLUX.1-dev-torchao-int4", safe_serialization=False)
```
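For intuition, `int4_weight_only` stores each weight as one of 16 integer levels plus a floating-point scale, dequantizing back to the original range on the fly. Below is a minimal symmetric, per-tensor sketch for illustration only; torchao's actual implementation uses per-group scales and packed int4 layouts.

```python
import numpy as np

def quantize_int4_symmetric(w: np.ndarray):
    # Map weights to integer levels in [-8, 7] using one per-tensor scale.
    scale = np.abs(w).max() / 7.0
    q = np.clip(np.round(w / scale), -8, 7).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    # Recover approximate float weights from the integer levels.
    return q.astype(np.float32) * scale

w = np.array([0.4, -1.0, 0.2, 0.9], dtype=np.float32)
q, scale = quantize_int4_symmetric(w)
w_hat = dequantize(q, scale)
# For unclipped values, reconstruction error is at most half a step (scale / 2).
```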