Spaces:
Running
on
Zero
Running
on
Zero
Nupur Kumari
commited on
Commit
·
e0f6273
0
Parent(s):
Initial commit
Browse files- .gitattributes +36 -0
- README.md +20 -0
- app.py +264 -0
- imgs/test_cases/action_figure/0.jpg +0 -0
- imgs/test_cases/action_figure/1.jpg +0 -0
- imgs/test_cases/action_figure/2.jpg +0 -0
- imgs/test_cases/penguin/0.jpg +0 -0
- imgs/test_cases/penguin/1.jpg +0 -0
- imgs/test_cases/penguin/2.jpg +0 -0
- imgs/test_cases/rc_car/02.jpg +0 -0
- imgs/test_cases/rc_car/03.jpg +0 -0
- imgs/test_cases/rc_car/04.jpg +0 -0
- models/pytorch_model.bin +3 -0
- pipelines/flux_pipeline/pipeline.py +443 -0
- pipelines/flux_pipeline/transformer.py +756 -0
- requirements.txt +9 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
models/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: SynCD
|
3 |
+
emoji: 🖼
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.4.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
tags:
|
12 |
+
- dwpose
|
13 |
+
- pose
|
14 |
+
- Text-to-Image
|
15 |
+
- Image-to-Image
|
16 |
+
- language models
|
17 |
+
- LLMs
|
18 |
+
short_description: Image generator/identifier/reposer
|
19 |
+
---
|
20 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import spaces
|
7 |
+
import torch
|
8 |
+
from einops import rearrange
|
9 |
+
from huggingface_hub import login
|
10 |
+
from peft import LoraConfig
|
11 |
+
from PIL import Image
|
12 |
+
from pipelines.flux_pipeline.pipeline import SynCDFluxPipeline
|
13 |
+
from pipelines.flux_pipeline.transformer import FluxTransformer2DModelWithMasking
|
14 |
+
|
15 |
+
HF_TOKEN = os.getenv('HF_TOKEN')
|
16 |
+
login(token=HF_TOKEN)
|
17 |
+
torch_dtype = torch.bfloat16
|
18 |
+
transformer = FluxTransformer2DModelWithMasking.from_pretrained(
|
19 |
+
'black-forest-labs/FLUX.1-dev',
|
20 |
+
subfolder='transformer',
|
21 |
+
torch_dtype=torch_dtype
|
22 |
+
)
|
23 |
+
pipeline = SynCDFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', transformer=transformer, torch_dtype=torch_dtype)
|
24 |
+
for name, attn_proc in pipeline.transformer.attn_processors.items():
|
25 |
+
attn_proc.name = name
|
26 |
+
|
27 |
+
target_modules=[
|
28 |
+
"to_k",
|
29 |
+
"to_q",
|
30 |
+
"to_v",
|
31 |
+
"add_k_proj",
|
32 |
+
"add_q_proj",
|
33 |
+
"add_v_proj",
|
34 |
+
"to_out.0",
|
35 |
+
"to_add_out",
|
36 |
+
"ff.net.0.proj",
|
37 |
+
"ff.net.2",
|
38 |
+
"ff_context.net.0.proj",
|
39 |
+
"ff_context.net.2",
|
40 |
+
"proj_mlp",
|
41 |
+
"proj_out",
|
42 |
+
]
|
43 |
+
lora_rank = 32
|
44 |
+
lora_config = LoraConfig(
|
45 |
+
r=lora_rank,
|
46 |
+
lora_alpha=lora_rank,
|
47 |
+
init_lora_weights="gaussian",
|
48 |
+
target_modules=target_modules,
|
49 |
+
)
|
50 |
+
pipeline.transformer.add_adapter(lora_config)
|
51 |
+
finetuned_path = torch.load('models/pytorch_model.bin', map_location='cpu')
|
52 |
+
transformer_dict = {}
|
53 |
+
for key,value in finetuned_path.items():
|
54 |
+
if 'transformer.base_model.model.' in key:
|
55 |
+
transformer_dict[key.replace('transformer.base_model.model.', '')] = value
|
56 |
+
pipeline.transformer.load_state_dict(transformer_dict, strict=False)
|
57 |
+
# pipeline.to('cuda')
|
58 |
+
pipeline.enable_vae_slicing()
|
59 |
+
pipeline.enable_vae_tiling()
|
60 |
+
|
61 |
+
@torch.no_grad()
|
62 |
+
def decode(latents, pipeline):
|
63 |
+
latents = latents / pipeline.vae.config.scaling_factor
|
64 |
+
image = pipeline.vae.decode(latents, return_dict=False)[0]
|
65 |
+
return image
|
66 |
+
|
67 |
+
|
68 |
+
@torch.no_grad()
|
69 |
+
def encode_target_images(images, pipeline):
|
70 |
+
latents = pipeline.vae.encode(images).latent_dist.sample()
|
71 |
+
latents = latents * pipeline.vae.config.scaling_factor
|
72 |
+
return latents
|
73 |
+
|
74 |
+
|
75 |
+
@spaces.GPU(duration=120)
|
76 |
+
def generate_image(text, img1, img2, img3, guidance_scale, inference_steps, seed, rigid_object, enable_cpu_offload=False):
|
77 |
+
if enable_cpu_offload:
|
78 |
+
pipeline.enable_sequential_cpu_offload()
|
79 |
+
input_images = [img1, img2, img3]
|
80 |
+
# Delete None
|
81 |
+
input_images = [img for img in input_images if img is not None]
|
82 |
+
if len(input_images) == 0:
|
83 |
+
return "Please upload at least one image"
|
84 |
+
numref = len(input_images) + 1
|
85 |
+
images = torch.cat([2. * torch.from_numpy(np.array(Image.open(img).convert('RGB').resize((512, 512)))).permute(2, 0, 1).unsqueeze(0).to(torch_dtype)/255. -1. for img in input_images])
|
86 |
+
images = images.to(pipeline.device)
|
87 |
+
latents = encode_target_images(images, pipeline)
|
88 |
+
latents = torch.cat([torch.zeros_like(latents[:1]), latents], dim=0)
|
89 |
+
masklatent = torch.zeros_like(latents)
|
90 |
+
masklatent[:1] = 1.
|
91 |
+
latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
|
92 |
+
masklatent = rearrange(masklatent, "(b n) c h w -> b c h (n w)", n=numref)
|
93 |
+
B, C, H, W = latents.shape
|
94 |
+
latents = pipeline._pack_latents(latents, B, C, H, W)
|
95 |
+
masklatent = pipeline._pack_latents(masklatent.expand(-1, C, -1, -1) ,B, C, H, W)
|
96 |
+
output = pipeline(
|
97 |
+
text,
|
98 |
+
latents_ref=latents,
|
99 |
+
latents_mask=masklatent,
|
100 |
+
guidance_scale=guidance_scale,
|
101 |
+
num_inference_steps=inference_steps,
|
102 |
+
height=512,
|
103 |
+
width=numref * 512,
|
104 |
+
generator = torch.Generator(device="cpu").manual_seed(seed),
|
105 |
+
joint_attention_kwargs={'shared_attn': True, 'num': numref},
|
106 |
+
return_dict=False,
|
107 |
+
)[0][0]
|
108 |
+
output = rearrange(output, "b c h (n w) -> (b n) c h w", n=numref)[::numref]
|
109 |
+
img = Image.fromarray( (( torch.clip(output[0].float(), -1., 1.).permute(1,2,0).cpu().numpy()*0.5+0.5)*255).astype(np.uint8) )
|
110 |
+
return img
|
111 |
+
|
112 |
+
|
113 |
+
|
114 |
+
def get_example():
|
115 |
+
case = [
|
116 |
+
[
|
117 |
+
"A toy on a beach. Waves in the background. Realistic shot.",
|
118 |
+
"./imgs/test_cases/rc_car/02.jpg",
|
119 |
+
"./imgs/test_cases/rc_car/03.jpg",
|
120 |
+
"./imgs/test_cases/rc_car/04.jpg",
|
121 |
+
3.5,
|
122 |
+
42,
|
123 |
+
True,
|
124 |
+
],
|
125 |
+
[
|
126 |
+
"An action figure on top of a mountain. Sunset in the background. Realistic shot.",
|
127 |
+
"./imgs/test_cases/action_figure/0.jpg",
|
128 |
+
"./imgs/test_cases/action_figure/1.jpg",
|
129 |
+
"./imgs/test_cases/action_figure/2.jpg",
|
130 |
+
3.5,
|
131 |
+
42,
|
132 |
+
True,
|
133 |
+
],
|
134 |
+
[
|
135 |
+
"A penguin plushing wearing pink sunglasses is lounging on a beach. Realistic shot.",
|
136 |
+
"./imgs/test_cases/penguin/0.jpg",
|
137 |
+
"./imgs/test_cases/penguin/1.jpg",
|
138 |
+
"./imgs/test_cases/penguin/2.jpg",
|
139 |
+
3.5,
|
140 |
+
42,
|
141 |
+
True,
|
142 |
+
],
|
143 |
+
]
|
144 |
+
return case
|
145 |
+
|
146 |
+
def run_for_examples(text, img1, img2, img3, guidance_scale, seed, rigid_object, enable_cpu_offload=False):
|
147 |
+
inference_steps = 30
|
148 |
+
|
149 |
+
return generate_image(
|
150 |
+
text, img1, img2, img3, guidance_scale, inference_steps, seed, rigid_object, enable_cpu_offload
|
151 |
+
)
|
152 |
+
|
153 |
+
description = """
|
154 |
+
Synthetic Customization Dataset (SynCD) consists of multiple images of the same object in different contexts. We achieve it by promoting similar object identity using either explicit 3D object assets or, more implicitly, using masked shared attention across different views while generating images. Given this training data, we train a new encoder-based model for the task, which can successfully generate new compositions of a reference object using text prompts. You can download our dataset [here](https://huggingface.co/datasets/nupurkmr9/syncd).
|
155 |
+
|
156 |
+
Our model supports multiple input images of the same object as references. You can upload up to 3 images, with better results on 3 images vs 1 image.
|
157 |
+
|
158 |
+
**HF Spaces often encounter errors due to quota limitations, so recommend to run it locally.**
|
159 |
+
"""
|
160 |
+
|
161 |
+
article = """
|
162 |
+
---
|
163 |
+
**Citation**
|
164 |
+
<br>
|
165 |
+
If you find this repository useful, please consider giving a star ⭐ and a citation
|
166 |
+
```
|
167 |
+
@article{kumari2025syncd,
|
168 |
+
title={Generating Multi-Image Synthetic Data for Text-to-Image Customization},
|
169 |
+
author={Kumari, Nupur and Yin, Xi and Zhu, Jun-Yan and Misra, Ishan and Azadi, Samaneh},
|
170 |
+
journal={ArXiv},
|
171 |
+
year={2025}
|
172 |
+
}
|
173 |
+
```
|
174 |
+
**Contact**
|
175 |
+
<br>
|
176 |
+
If you have any questions, please feel free to open an issue or directly reach us out via email.
|
177 |
+
|
178 |
+
**Acknowledgement**
|
179 |
+
<br>
|
180 |
+
This space was modified from [OmniGen](https://huggingface.co/spaces/Shitao/OmniGen) space.
|
181 |
+
"""
|
182 |
+
|
183 |
+
|
184 |
+
# Gradio
|
185 |
+
with gr.Blocks() as demo:
|
186 |
+
gr.Markdown("# SynCD: Generating Multi-Image Synthetic Data for Text-to-Image Customization [[paper](https://arxiv.org/abs/2502.01720)] [[code](https://github.com/nupurkmr9/syncd)]")
|
187 |
+
gr.Markdown(description)
|
188 |
+
with gr.Row():
|
189 |
+
with gr.Column():
|
190 |
+
# text prompt
|
191 |
+
prompt_input = gr.Textbox(
|
192 |
+
label="Enter your prompt, more descriptive prompt will lead to better results", placeholder="Type your prompt here..."
|
193 |
+
)
|
194 |
+
|
195 |
+
with gr.Row(equal_height=True):
|
196 |
+
# input images
|
197 |
+
image_input_1 = gr.Image(label="img1", type="filepath")
|
198 |
+
image_input_2 = gr.Image(label="img2", type="filepath")
|
199 |
+
image_input_3 = gr.Image(label="img3", type="filepath")
|
200 |
+
|
201 |
+
guidance_scale_input = gr.Slider(
|
202 |
+
label="Guidance Scale", minimum=1.0, maximum=5.0, value=3.5, step=0.1
|
203 |
+
)
|
204 |
+
|
205 |
+
num_inference_steps = gr.Slider(
|
206 |
+
label="Inference Steps", minimum=1, maximum=100, value=30, step=1
|
207 |
+
)
|
208 |
+
|
209 |
+
seed_input = gr.Slider(
|
210 |
+
label="Seed", minimum=0, maximum=2147483647, value=42, step=1
|
211 |
+
)
|
212 |
+
|
213 |
+
rigid_object = gr.Checkbox(
|
214 |
+
label="rigid_object", info="Whether its a rigid object or a deformable object like pet animals, wearable etc.", value=True,
|
215 |
+
)
|
216 |
+
enable_cpu_offload = gr.Checkbox(
|
217 |
+
label="Enable CPU Offload", info="Enable CPU Offload to avoid memory issues", value=False,
|
218 |
+
)
|
219 |
+
|
220 |
+
# generate
|
221 |
+
generate_button = gr.Button("Generate Image")
|
222 |
+
|
223 |
+
|
224 |
+
with gr.Column():
|
225 |
+
# output image
|
226 |
+
output_image = gr.Image(label="Output Image")
|
227 |
+
|
228 |
+
# click
|
229 |
+
generate_button.click(
|
230 |
+
generate_image,
|
231 |
+
inputs=[
|
232 |
+
prompt_input,
|
233 |
+
image_input_1,
|
234 |
+
image_input_2,
|
235 |
+
image_input_3,
|
236 |
+
guidance_scale_input,
|
237 |
+
num_inference_steps,
|
238 |
+
seed_input,
|
239 |
+
rigid_object,
|
240 |
+
enable_cpu_offload,
|
241 |
+
],
|
242 |
+
outputs=output_image,
|
243 |
+
)
|
244 |
+
|
245 |
+
gr.Examples(
|
246 |
+
examples=get_example(),
|
247 |
+
fn=run_for_examples,
|
248 |
+
inputs=[
|
249 |
+
prompt_input,
|
250 |
+
image_input_1,
|
251 |
+
image_input_2,
|
252 |
+
image_input_3,
|
253 |
+
guidance_scale_input,
|
254 |
+
seed_input,
|
255 |
+
rigid_object,
|
256 |
+
],
|
257 |
+
outputs=output_image,
|
258 |
+
)
|
259 |
+
|
260 |
+
gr.Markdown(article)
|
261 |
+
|
262 |
+
# launch
|
263 |
+
demo.launch()
|
264 |
+
|
imgs/test_cases/action_figure/0.jpg
ADDED
![]() |
imgs/test_cases/action_figure/1.jpg
ADDED
![]() |
imgs/test_cases/action_figure/2.jpg
ADDED
![]() |
imgs/test_cases/penguin/0.jpg
ADDED
![]() |
imgs/test_cases/penguin/1.jpg
ADDED
![]() |
imgs/test_cases/penguin/2.jpg
ADDED
![]() |
imgs/test_cases/rc_car/02.jpg
ADDED
![]() |
imgs/test_cases/rc_car/03.jpg
ADDED
![]() |
imgs/test_cases/rc_car/04.jpg
ADDED
![]() |
models/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1947b2008809a98ef1d77b0c98365ccfb9f8b4285873ab3b26cfe43b58a2f4c6
|
3 |
+
size 358868218
|
pipelines/flux_pipeline/pipeline.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from transformers import (
|
21 |
+
CLIPImageProcessor,
|
22 |
+
CLIPTextModel,
|
23 |
+
CLIPTokenizer,
|
24 |
+
CLIPVisionModelWithProjection,
|
25 |
+
T5EncoderModel,
|
26 |
+
T5TokenizerFast,
|
27 |
+
)
|
28 |
+
|
29 |
+
from diffusers import FluxPipeline
|
30 |
+
from diffusers.image_processor import VaeImageProcessor
|
31 |
+
from diffusers.loaders import FluxLoraLoaderMixin
|
32 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
33 |
+
from diffusers.models.transformers import FluxTransformer2DModel
|
34 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
35 |
+
from diffusers.utils import USE_PEFT_BACKEND, is_torch_xla_available
|
36 |
+
|
37 |
+
if is_torch_xla_available():
|
38 |
+
import torch_xla.core.xla_model as xm
|
39 |
+
|
40 |
+
XLA_AVAILABLE = True
|
41 |
+
else:
|
42 |
+
XLA_AVAILABLE = False
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
def calculate_shift(
|
47 |
+
image_seq_len,
|
48 |
+
base_seq_len: int = 256,
|
49 |
+
max_seq_len: int = 4096,
|
50 |
+
base_shift: float = 0.5,
|
51 |
+
max_shift: float = 1.16,
|
52 |
+
):
|
53 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
54 |
+
b = base_shift - m * base_seq_len
|
55 |
+
mu = image_seq_len * m + b
|
56 |
+
return mu
|
57 |
+
|
58 |
+
|
59 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
60 |
+
def retrieve_timesteps(
|
61 |
+
scheduler,
|
62 |
+
num_inference_steps: Optional[int] = None,
|
63 |
+
device: Optional[Union[str, torch.device]] = None,
|
64 |
+
timesteps: Optional[List[int]] = None,
|
65 |
+
sigmas: Optional[List[float]] = None,
|
66 |
+
**kwargs,):
|
67 |
+
if timesteps is not None and sigmas is not None:
|
68 |
+
raise ValueError(
|
69 |
+
"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
|
70 |
+
)
|
71 |
+
if timesteps is not None:
|
72 |
+
accepts_timesteps = "timesteps" in set(
|
73 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
74 |
+
)
|
75 |
+
if not accepts_timesteps:
|
76 |
+
raise ValueError(
|
77 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
78 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
79 |
+
)
|
80 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
81 |
+
timesteps = scheduler.timesteps
|
82 |
+
num_inference_steps = len(timesteps)
|
83 |
+
elif sigmas is not None:
|
84 |
+
accept_sigmas = "sigmas" in set(
|
85 |
+
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
86 |
+
)
|
87 |
+
if not accept_sigmas:
|
88 |
+
raise ValueError(
|
89 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
90 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
91 |
+
)
|
92 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
93 |
+
timesteps = scheduler.timesteps
|
94 |
+
num_inference_steps = len(timesteps)
|
95 |
+
else:
|
96 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
97 |
+
timesteps = scheduler.timesteps
|
98 |
+
return timesteps, num_inference_steps
|
99 |
+
|
100 |
+
|
101 |
+
class SynCDFluxPipeline(FluxPipeline):
|
102 |
+
|
103 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
104 |
+
_optional_components = []
|
105 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
106 |
+
|
107 |
+
def __init__(
|
108 |
+
self,
|
109 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
110 |
+
vae: AutoencoderKL,
|
111 |
+
text_encoder: CLIPTextModel,
|
112 |
+
tokenizer: CLIPTokenizer,
|
113 |
+
text_encoder_2: T5EncoderModel,
|
114 |
+
tokenizer_2: T5TokenizerFast,
|
115 |
+
transformer: FluxTransformer2DModel,
|
116 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
117 |
+
feature_extractor: CLIPImageProcessor = None,
|
118 |
+
###
|
119 |
+
num=2,
|
120 |
+
):
|
121 |
+
super().__init__(
|
122 |
+
vae=vae,
|
123 |
+
text_encoder=text_encoder,
|
124 |
+
text_encoder_2=text_encoder_2,
|
125 |
+
tokenizer=tokenizer,
|
126 |
+
tokenizer_2=tokenizer_2,
|
127 |
+
transformer=transformer,
|
128 |
+
scheduler=scheduler,
|
129 |
+
image_encoder=image_encoder,
|
130 |
+
feature_extractor=feature_extractor
|
131 |
+
)
|
132 |
+
self.default_sample_size = 64
|
133 |
+
self.num = num
|
134 |
+
|
135 |
+
@torch.no_grad()
|
136 |
+
def __call__(
|
137 |
+
self,
|
138 |
+
prompt: Union[str, List[str]] = None,
|
139 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
140 |
+
negative_prompt: Union[str, List[str]] = None,
|
141 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
142 |
+
true_cfg_scale: float = 1.0,
|
143 |
+
height: Optional[int] = None,
|
144 |
+
width: Optional[int] = None,
|
145 |
+
num_inference_steps: int = 28,
|
146 |
+
sigmas: Optional[List[float]] = None,
|
147 |
+
guidance_scale: float = 3.5,
|
148 |
+
num_images_per_prompt: Optional[int] = 1,
|
149 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
150 |
+
latents: Optional[torch.FloatTensor] = None,
|
151 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
152 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
153 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
154 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
155 |
+
output_type: Optional[str] = "pil",
|
156 |
+
return_dict: bool = True,
|
157 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
158 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
159 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
160 |
+
max_sequence_length: int = 512,
|
161 |
+
#####
|
162 |
+
latents_ref: Optional[torch.Tensor] = None,
|
163 |
+
latents_mask: Optional[torch.Tensor] = None,
|
164 |
+
return_latents: bool=False,
|
165 |
+
):
|
166 |
+
r"""
|
167 |
+
Function invoked when calling the pipeline for generation.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
prompt (`str` or `List[str]`, *optional*):
|
171 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
172 |
+
instead.
|
173 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
174 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
175 |
+
will be used instead.
|
176 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
177 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
178 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
|
179 |
+
not greater than `1`).
|
180 |
+
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
181 |
+
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
182 |
+
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
183 |
+
true_cfg_scale (`float`, *optional*, defaults to 1.0):
|
184 |
+
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
|
185 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
186 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
187 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
188 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
189 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
190 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
191 |
+
expense of slower inference.
|
192 |
+
sigmas (`List[float]`, *optional*):
|
193 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
194 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
195 |
+
will be used.
|
196 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
197 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
198 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
199 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
200 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
201 |
+
usually at the expense of lower image quality.
|
202 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
203 |
+
The number of images to generate per prompt.
|
204 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
205 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
206 |
+
to make generation deterministic.
|
207 |
+
latents (`torch.FloatTensor`, *optional*):
|
208 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
209 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
210 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
211 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
212 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
213 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
214 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
215 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
216 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
217 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
218 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
219 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
220 |
+
argument.
|
221 |
+
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
222 |
+
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
223 |
+
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
224 |
+
input argument.
|
225 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
226 |
+
The output format of the generate image. Choose between
|
227 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
228 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
229 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
230 |
+
joint_attention_kwargs (`dict`, *optional*):
|
231 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
232 |
+
`self.processor` in
|
233 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
234 |
+
callback_on_step_end (`Callable`, *optional*):
|
235 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
236 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
237 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
238 |
+
`callback_on_step_end_tensor_inputs`.
|
239 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
240 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
241 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
242 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
243 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
244 |
+
|
245 |
+
Examples:
|
246 |
+
|
247 |
+
Returns:
|
248 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
249 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
250 |
+
images.
|
251 |
+
"""
|
252 |
+
|
253 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
254 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
255 |
+
|
256 |
+
# 1. Check inputs. Raise error if not correct
|
257 |
+
self.check_inputs(
|
258 |
+
prompt,
|
259 |
+
prompt_2,
|
260 |
+
height,
|
261 |
+
width,
|
262 |
+
negative_prompt=negative_prompt,
|
263 |
+
negative_prompt_2=negative_prompt_2,
|
264 |
+
prompt_embeds=prompt_embeds,
|
265 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
266 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
267 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
268 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
269 |
+
max_sequence_length=max_sequence_length,
|
270 |
+
)
|
271 |
+
|
272 |
+
self._guidance_scale = guidance_scale
|
273 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
274 |
+
self._current_timestep = None
|
275 |
+
self._interrupt = False
|
276 |
+
|
277 |
+
# 2. Define call parameters
|
278 |
+
if prompt is not None and isinstance(prompt, str):
|
279 |
+
batch_size = 1
|
280 |
+
elif prompt is not None and isinstance(prompt, list):
|
281 |
+
batch_size = len(prompt)
|
282 |
+
else:
|
283 |
+
batch_size = prompt_embeds.shape[0]
|
284 |
+
|
285 |
+
device = self._execution_device
|
286 |
+
|
287 |
+
lora_scale = (
|
288 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
289 |
+
)
|
290 |
+
has_neg_prompt = negative_prompt is not None or (
|
291 |
+
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
|
292 |
+
)
|
293 |
+
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
|
294 |
+
(
|
295 |
+
prompt_embeds,
|
296 |
+
pooled_prompt_embeds,
|
297 |
+
text_ids,
|
298 |
+
) = self.encode_prompt(
|
299 |
+
prompt=prompt,
|
300 |
+
prompt_2=prompt_2,
|
301 |
+
prompt_embeds=prompt_embeds,
|
302 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
303 |
+
device=device,
|
304 |
+
num_images_per_prompt=num_images_per_prompt,
|
305 |
+
max_sequence_length=max_sequence_length,
|
306 |
+
lora_scale=lora_scale,
|
307 |
+
)
|
308 |
+
if do_true_cfg:
|
309 |
+
(
|
310 |
+
negative_prompt_embeds,
|
311 |
+
negative_pooled_prompt_embeds,
|
312 |
+
_,
|
313 |
+
) = self.encode_prompt(
|
314 |
+
prompt=negative_prompt,
|
315 |
+
prompt_2=negative_prompt_2,
|
316 |
+
prompt_embeds=negative_prompt_embeds,
|
317 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
318 |
+
device=device,
|
319 |
+
num_images_per_prompt=num_images_per_prompt,
|
320 |
+
max_sequence_length=max_sequence_length,
|
321 |
+
lora_scale=lora_scale,
|
322 |
+
)
|
323 |
+
|
324 |
+
# 4. Prepare latent variables
|
325 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
326 |
+
latents, latent_image_ids = self.prepare_latents(
|
327 |
+
batch_size * num_images_per_prompt,
|
328 |
+
num_channels_latents,
|
329 |
+
height,
|
330 |
+
width,
|
331 |
+
prompt_embeds.dtype,
|
332 |
+
device,
|
333 |
+
generator,
|
334 |
+
latents,
|
335 |
+
)
|
336 |
+
|
337 |
+
# 5. Prepare timesteps
|
338 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
339 |
+
image_seq_len = latents.shape[1]
|
340 |
+
mu = calculate_shift(
|
341 |
+
image_seq_len,
|
342 |
+
self.scheduler.config.get("base_image_seq_len", 256),
|
343 |
+
self.scheduler.config.get("max_image_seq_len", 4096),
|
344 |
+
self.scheduler.config.get("base_shift", 0.5),
|
345 |
+
self.scheduler.config.get("max_shift", 1.15),
|
346 |
+
)
|
347 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
348 |
+
self.scheduler,
|
349 |
+
num_inference_steps,
|
350 |
+
device,
|
351 |
+
sigmas=sigmas,
|
352 |
+
mu=mu,
|
353 |
+
)
|
354 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
355 |
+
self._num_timesteps = len(timesteps)
|
356 |
+
|
357 |
+
# handle guidance
|
358 |
+
if self.transformer.config.guidance_embeds:
|
359 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
360 |
+
guidance = guidance.expand(latents.shape[0])
|
361 |
+
else:
|
362 |
+
guidance = None
|
363 |
+
|
364 |
+
if self.joint_attention_kwargs is None:
|
365 |
+
self._joint_attention_kwargs = {}
|
366 |
+
|
367 |
+
# 6. Denoising loop
|
368 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
369 |
+
for i, t in enumerate(timesteps):
|
370 |
+
if self.interrupt:
|
371 |
+
continue
|
372 |
+
|
373 |
+
self._current_timestep = t
|
374 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
375 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
376 |
+
self.joint_attention_kwargs.update({'timestep': t/1000, 'val': True})
|
377 |
+
if self.joint_attention_kwargs is not None and self.joint_attention_kwargs['shared_attn'] and latents_ref is not None and latents_mask is not None:
|
378 |
+
latents = (1 - latents_mask) * latents_ref + latents_mask * latents
|
379 |
+
|
380 |
+
noise_pred = self.transformer(
|
381 |
+
hidden_states=latents,
|
382 |
+
timestep=timestep / 1000,
|
383 |
+
guidance=guidance,
|
384 |
+
pooled_projections=pooled_prompt_embeds,
|
385 |
+
encoder_hidden_states=prompt_embeds,
|
386 |
+
txt_ids=text_ids,
|
387 |
+
img_ids=latent_image_ids,
|
388 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
389 |
+
return_dict=False,
|
390 |
+
)[0]
|
391 |
+
|
392 |
+
if do_true_cfg:
|
393 |
+
neg_noise_pred = self.transformer(
|
394 |
+
hidden_states=latents,
|
395 |
+
timestep=timestep / 1000,
|
396 |
+
guidance=guidance,
|
397 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
398 |
+
encoder_hidden_states=negative_prompt_embeds,
|
399 |
+
txt_ids=text_ids,
|
400 |
+
img_ids=latent_image_ids,
|
401 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
402 |
+
return_dict=False,
|
403 |
+
)[0]
|
404 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
405 |
+
|
406 |
+
# compute the previous noisy sample x_t -> x_t-1
|
407 |
+
latents_dtype = latents.dtype
|
408 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
409 |
+
|
410 |
+
if latents.dtype != latents_dtype:
|
411 |
+
if torch.backends.mps.is_available():
|
412 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
413 |
+
latents = latents.to(latents_dtype)
|
414 |
+
|
415 |
+
if callback_on_step_end is not None:
|
416 |
+
callback_kwargs = {}
|
417 |
+
for k in callback_on_step_end_tensor_inputs:
|
418 |
+
callback_kwargs[k] = locals()[k]
|
419 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
420 |
+
|
421 |
+
latents = callback_outputs.pop("latents", latents)
|
422 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
423 |
+
|
424 |
+
# call the callback, if provided
|
425 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
426 |
+
progress_bar.update()
|
427 |
+
|
428 |
+
if XLA_AVAILABLE:
|
429 |
+
xm.mark_step()
|
430 |
+
|
431 |
+
self._current_timestep = None
|
432 |
+
|
433 |
+
if output_type == "latent":
|
434 |
+
image = latents
|
435 |
+
else:
|
436 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
437 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
438 |
+
image = self.vae.decode(latents, return_dict=False)
|
439 |
+
|
440 |
+
# Offload all models
|
441 |
+
self.maybe_free_model_hooks()
|
442 |
+
|
443 |
+
return (image,)
|
pipelines/flux_pipeline/transformer.py
ADDED
@@ -0,0 +1,756 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# https://github.com/bghira/SimpleTuner/blob/d0b5f37913a80aabdb0cac893937072dfa3e6a4b/helpers/models/flux/transformer.py#L404
|
2 |
+
# Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
|
3 |
+
#
|
4 |
+
# Originally licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
|
6 |
+
|
7 |
+
import math
|
8 |
+
from contextlib import contextmanager
|
9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from einops import rearrange
|
15 |
+
from peft.tuners.lora.layer import LoraLayer
|
16 |
+
|
17 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
18 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
19 |
+
from diffusers.models.attention import FeedForward
|
20 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
21 |
+
from diffusers.models.embeddings import (
|
22 |
+
CombinedTimestepGuidanceTextProjEmbeddings,
|
23 |
+
CombinedTimestepTextProjEmbeddings,
|
24 |
+
FluxPosEmbed,
|
25 |
+
)
|
26 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
27 |
+
from diffusers.models.modeling_utils import ModelMixin
|
28 |
+
from diffusers.models.normalization import (
|
29 |
+
AdaLayerNormContinuous,
|
30 |
+
AdaLayerNormZero,
|
31 |
+
AdaLayerNormZeroSingle,
|
32 |
+
)
|
33 |
+
from diffusers.utils import (
|
34 |
+
USE_PEFT_BACKEND,
|
35 |
+
is_torch_version,
|
36 |
+
logging,
|
37 |
+
scale_lora_layers,
|
38 |
+
unscale_lora_layers,
|
39 |
+
)
|
40 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
41 |
+
|
42 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
+
|
44 |
+
def log_scale_masking(value, min_value=1, max_value=10):
|
45 |
+
# Convert the value into a positive domain for the logarithmic function
|
46 |
+
normalized_value = 1*value
|
47 |
+
|
48 |
+
# Apply logarithmic scaling
|
49 |
+
# log_scaled_value = 1-np.exp(-normalized_value)
|
50 |
+
log_scaled_value = 2.0* math.log(normalized_value+1, 2) / math.log(2, 2) # np.log1p(x) = log(1 + x)
|
51 |
+
# print(log_scaled_value)
|
52 |
+
|
53 |
+
# Rescale to original range
|
54 |
+
scaled_value = log_scaled_value * (max_value - min_value) + min_value
|
55 |
+
|
56 |
+
return min(max_value, int(scaled_value))
|
57 |
+
|
58 |
+
class FluxAttnProcessor2_0:
|
59 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
60 |
+
|
61 |
+
def __init__(self):
|
62 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
63 |
+
raise ImportError(
|
64 |
+
"FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
65 |
+
)
|
66 |
+
self.name = None
|
67 |
+
|
68 |
+
def __call__(
|
69 |
+
self,
|
70 |
+
attn: Attention,
|
71 |
+
hidden_states: torch.FloatTensor,
|
72 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
73 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
74 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
75 |
+
shared_attn: bool=False, num=2,
|
76 |
+
mode="a",
|
77 |
+
ref_dict: dict = None,
|
78 |
+
single: bool=False,
|
79 |
+
scale: float = 1.0,
|
80 |
+
timestep: float = 0,
|
81 |
+
val: bool = False,
|
82 |
+
) -> torch.FloatTensor:
|
83 |
+
if mode == 'w': # and single:
|
84 |
+
ref_dict[self.name] = hidden_states.detach()
|
85 |
+
|
86 |
+
batch_size, _, _ = (
|
87 |
+
hidden_states.shape
|
88 |
+
if encoder_hidden_states is None
|
89 |
+
else encoder_hidden_states.shape
|
90 |
+
)
|
91 |
+
end_of_hidden_states = hidden_states.shape[1]
|
92 |
+
text_seq = 512
|
93 |
+
mask = None
|
94 |
+
query = attn.to_q(hidden_states)
|
95 |
+
key = attn.to_k(hidden_states)
|
96 |
+
value = attn.to_v(hidden_states)
|
97 |
+
|
98 |
+
inner_dim = key.shape[-1]
|
99 |
+
head_dim = inner_dim // attn.heads
|
100 |
+
|
101 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
102 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
103 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
104 |
+
|
105 |
+
if attn.norm_q is not None:
|
106 |
+
query = attn.norm_q(query)
|
107 |
+
if attn.norm_k is not None:
|
108 |
+
key = attn.norm_k(key)
|
109 |
+
|
110 |
+
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
111 |
+
if encoder_hidden_states is not None:
|
112 |
+
# `context` projections.
|
113 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
114 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
115 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
116 |
+
|
117 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
118 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
119 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
120 |
+
|
121 |
+
if attn.norm_added_q is not None:
|
122 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
123 |
+
if attn.norm_added_k is not None:
|
124 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
125 |
+
|
126 |
+
# attention
|
127 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
128 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
129 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
130 |
+
|
131 |
+
if image_rotary_emb is not None:
|
132 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
133 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
134 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
135 |
+
|
136 |
+
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask if timestep < 1. else None)
|
137 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
138 |
+
|
139 |
+
hidden_states = hidden_states.to(query.dtype)
|
140 |
+
|
141 |
+
if encoder_hidden_states is not None:
|
142 |
+
encoder_hidden_states, hidden_states = (
|
143 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
144 |
+
hidden_states[:, encoder_hidden_states.shape[1] : ],
|
145 |
+
)
|
146 |
+
hidden_states = hidden_states[:, :end_of_hidden_states]
|
147 |
+
|
148 |
+
# linear proj
|
149 |
+
hidden_states = attn.to_out[0](hidden_states)
|
150 |
+
# dropout
|
151 |
+
hidden_states = attn.to_out[1](hidden_states)
|
152 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
153 |
+
return hidden_states, encoder_hidden_states
|
154 |
+
else:
|
155 |
+
return hidden_states[:, :end_of_hidden_states]
|
156 |
+
|
157 |
+
|
158 |
+
def expand_flux_attention_mask(
|
159 |
+
hidden_states: torch.Tensor,
|
160 |
+
attn_mask: torch.Tensor,
|
161 |
+
) -> torch.Tensor:
|
162 |
+
"""
|
163 |
+
Expand a mask so that the image is included.
|
164 |
+
"""
|
165 |
+
bsz = attn_mask.shape[0]
|
166 |
+
assert bsz == hidden_states.shape[0]
|
167 |
+
residual_seq_len = hidden_states.shape[1]
|
168 |
+
mask_seq_len = attn_mask.shape[1]
|
169 |
+
|
170 |
+
expanded_mask = torch.ones(bsz, residual_seq_len)
|
171 |
+
expanded_mask[:, :mask_seq_len] = attn_mask
|
172 |
+
|
173 |
+
return expanded_mask
|
174 |
+
|
175 |
+
|
176 |
+
@maybe_allow_in_graph
|
177 |
+
class FluxSingleTransformerBlock(nn.Module):
|
178 |
+
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
|
179 |
+
super().__init__()
|
180 |
+
self.mlp_hidden_dim = int(dim * mlp_ratio)
|
181 |
+
|
182 |
+
self.norm = AdaLayerNormZeroSingle(dim)
|
183 |
+
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
|
184 |
+
self.act_mlp = nn.GELU(approximate="tanh")
|
185 |
+
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
186 |
+
|
187 |
+
processor = FluxAttnProcessor2_0()
|
188 |
+
# processor = FluxSingleAttnProcessor3_0()
|
189 |
+
|
190 |
+
self.attn = Attention(
|
191 |
+
query_dim=dim,
|
192 |
+
cross_attention_dim=None,
|
193 |
+
dim_head=attention_head_dim,
|
194 |
+
heads=num_attention_heads,
|
195 |
+
out_dim=dim,
|
196 |
+
bias=True,
|
197 |
+
processor=processor,
|
198 |
+
qk_norm="rms_norm",
|
199 |
+
eps=1e-6,
|
200 |
+
pre_only=True,
|
201 |
+
)
|
202 |
+
|
203 |
+
def forward(
|
204 |
+
self,
|
205 |
+
hidden_states: torch.FloatTensor,
|
206 |
+
temb: torch.FloatTensor,
|
207 |
+
image_rotary_emb=None,
|
208 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
209 |
+
):
|
210 |
+
residual = hidden_states
|
211 |
+
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
212 |
+
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
213 |
+
|
214 |
+
attn_output = self.attn(
|
215 |
+
hidden_states=norm_hidden_states,
|
216 |
+
image_rotary_emb=image_rotary_emb,
|
217 |
+
**joint_attention_kwargs,
|
218 |
+
single=True,
|
219 |
+
)
|
220 |
+
|
221 |
+
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
222 |
+
gate = gate.unsqueeze(1)
|
223 |
+
hidden_states = gate * self.proj_out(hidden_states)
|
224 |
+
hidden_states = residual + hidden_states
|
225 |
+
|
226 |
+
return hidden_states
|
227 |
+
|
228 |
+
|
229 |
+
@maybe_allow_in_graph
|
230 |
+
class FluxTransformerBlock(nn.Module):
|
231 |
+
def __init__(
|
232 |
+
self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
|
233 |
+
):
|
234 |
+
super().__init__()
|
235 |
+
|
236 |
+
self.norm1 = AdaLayerNormZero(dim)
|
237 |
+
|
238 |
+
self.norm1_context = AdaLayerNormZero(dim)
|
239 |
+
|
240 |
+
if hasattr(F, "scaled_dot_product_attention"):
|
241 |
+
processor = FluxAttnProcessor2_0()
|
242 |
+
else:
|
243 |
+
raise ValueError(
|
244 |
+
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
245 |
+
)
|
246 |
+
self.attn = Attention(
|
247 |
+
query_dim=dim,
|
248 |
+
cross_attention_dim=None,
|
249 |
+
added_kv_proj_dim=dim,
|
250 |
+
dim_head=attention_head_dim,
|
251 |
+
heads=num_attention_heads,
|
252 |
+
out_dim=dim,
|
253 |
+
context_pre_only=False,
|
254 |
+
bias=True,
|
255 |
+
processor=processor,
|
256 |
+
qk_norm=qk_norm,
|
257 |
+
eps=eps,
|
258 |
+
)
|
259 |
+
|
260 |
+
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
261 |
+
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
262 |
+
|
263 |
+
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
264 |
+
self.ff_context = FeedForward(
|
265 |
+
dim=dim, dim_out=dim, activation_fn="gelu-approximate"
|
266 |
+
)
|
267 |
+
|
268 |
+
# let chunk size default to None
|
269 |
+
self._chunk_size = None
|
270 |
+
self._chunk_dim = 0
|
271 |
+
|
272 |
+
def forward(
|
273 |
+
self,
|
274 |
+
hidden_states: torch.FloatTensor,
|
275 |
+
encoder_hidden_states: torch.FloatTensor,
|
276 |
+
temb: torch.FloatTensor,
|
277 |
+
image_rotary_emb=None,
|
278 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None
|
279 |
+
):
|
280 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
281 |
+
|
282 |
+
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (self.norm1_context(encoder_hidden_states, emb=temb))
|
283 |
+
|
284 |
+
# Attention.
|
285 |
+
attn_output, context_attn_output = self.attn(
|
286 |
+
hidden_states=norm_hidden_states,
|
287 |
+
encoder_hidden_states=norm_encoder_hidden_states,
|
288 |
+
image_rotary_emb=image_rotary_emb,
|
289 |
+
**joint_attention_kwargs,
|
290 |
+
single=False,
|
291 |
+
)
|
292 |
+
|
293 |
+
# Process attention outputs for the `hidden_states`.
|
294 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
295 |
+
hidden_states = hidden_states + attn_output
|
296 |
+
|
297 |
+
norm_hidden_states = self.norm2(hidden_states)
|
298 |
+
norm_hidden_states = (norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None])
|
299 |
+
|
300 |
+
ff_output = self.ff(norm_hidden_states)
|
301 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
302 |
+
|
303 |
+
hidden_states = hidden_states + ff_output
|
304 |
+
|
305 |
+
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
306 |
+
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
307 |
+
|
308 |
+
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
309 |
+
norm_encoder_hidden_states = (
|
310 |
+
norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
|
311 |
+
+ c_shift_mlp[:, None]
|
312 |
+
)
|
313 |
+
|
314 |
+
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
315 |
+
encoder_hidden_states = (
|
316 |
+
encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
317 |
+
)
|
318 |
+
|
319 |
+
return encoder_hidden_states, hidden_states
|
320 |
+
|
321 |
+
|
322 |
+
@contextmanager
|
323 |
+
def set_adapter_scale(model, alpha):
|
324 |
+
original_scaling = {}
|
325 |
+
for module in model.modules():
|
326 |
+
if isinstance(module, LoraLayer):
|
327 |
+
original_scaling[module] = module.scaling.copy()
|
328 |
+
module.scaling = {k: v * alpha for k, v in module.scaling.items()}
|
329 |
+
|
330 |
+
# check whether scaling is prohibited on model
|
331 |
+
# the original scaling dictionary should be empty
|
332 |
+
# if there were no lora layers
|
333 |
+
if not original_scaling:
|
334 |
+
raise ValueError("scaling is only supported for models with `LoraLayer`s")
|
335 |
+
try:
|
336 |
+
yield
|
337 |
+
|
338 |
+
finally:
|
339 |
+
# restore original scaling values after exiting the context
|
340 |
+
for module, scaling in original_scaling.items():
|
341 |
+
module.scaling = scaling
|
342 |
+
|
343 |
+
class FluxTransformer2DModelWithMasking(
|
344 |
+
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
|
345 |
+
):
|
346 |
+
"""
|
347 |
+
The Transformer model introduced in Flux.
|
348 |
+
|
349 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
350 |
+
|
351 |
+
Parameters:
|
352 |
+
patch_size (`int`): Patch size to turn the input data into small patches.
|
353 |
+
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
354 |
+
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
|
355 |
+
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
|
356 |
+
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
357 |
+
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
358 |
+
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
359 |
+
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
360 |
+
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
|
361 |
+
"""
|
362 |
+
|
363 |
+
_supports_gradient_checkpointing = True
|
364 |
+
|
365 |
+
@register_to_config
|
366 |
+
def __init__(
|
367 |
+
self,
|
368 |
+
patch_size: int = 1,
|
369 |
+
in_channels: int = 64,
|
370 |
+
num_layers: int = 19,
|
371 |
+
num_single_layers: int = 38,
|
372 |
+
attention_head_dim: int = 128,
|
373 |
+
num_attention_heads: int = 24,
|
374 |
+
joint_attention_dim: int = 4096,
|
375 |
+
pooled_projection_dim: int = 768,
|
376 |
+
guidance_embeds: bool = False,
|
377 |
+
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
378 |
+
##
|
379 |
+
):
|
380 |
+
super().__init__()
|
381 |
+
self.out_channels = in_channels
|
382 |
+
self.inner_dim = (
|
383 |
+
self.config.num_attention_heads * self.config.attention_head_dim
|
384 |
+
)
|
385 |
+
|
386 |
+
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
387 |
+
text_time_guidance_cls = (
|
388 |
+
CombinedTimestepGuidanceTextProjEmbeddings
|
389 |
+
if guidance_embeds
|
390 |
+
else CombinedTimestepTextProjEmbeddings
|
391 |
+
)
|
392 |
+
self.time_text_embed = text_time_guidance_cls(
|
393 |
+
embedding_dim=self.inner_dim,
|
394 |
+
pooled_projection_dim=self.config.pooled_projection_dim,
|
395 |
+
)
|
396 |
+
|
397 |
+
self.context_embedder = nn.Linear(
|
398 |
+
self.config.joint_attention_dim, self.inner_dim
|
399 |
+
)
|
400 |
+
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
|
401 |
+
|
402 |
+
self.transformer_blocks = nn.ModuleList(
|
403 |
+
[
|
404 |
+
FluxTransformerBlock(
|
405 |
+
dim=self.inner_dim,
|
406 |
+
num_attention_heads=self.config.num_attention_heads,
|
407 |
+
attention_head_dim=self.config.attention_head_dim,
|
408 |
+
)
|
409 |
+
for i in range(self.config.num_layers)
|
410 |
+
]
|
411 |
+
)
|
412 |
+
|
413 |
+
self.single_transformer_blocks = nn.ModuleList(
|
414 |
+
[
|
415 |
+
FluxSingleTransformerBlock(
|
416 |
+
dim=self.inner_dim,
|
417 |
+
num_attention_heads=self.config.num_attention_heads,
|
418 |
+
attention_head_dim=self.config.attention_head_dim,
|
419 |
+
)
|
420 |
+
for i in range(self.config.num_single_layers)
|
421 |
+
]
|
422 |
+
)
|
423 |
+
|
424 |
+
self.norm_out = AdaLayerNormContinuous(
|
425 |
+
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
|
426 |
+
)
|
427 |
+
self.proj_out = nn.Linear(
|
428 |
+
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
|
429 |
+
)
|
430 |
+
|
431 |
+
self.gradient_checkpointing = False
|
432 |
+
|
433 |
+
@property
|
434 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
435 |
+
r"""
|
436 |
+
Returns:
|
437 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
438 |
+
indexed by its weight name.
|
439 |
+
"""
|
440 |
+
# set recursively
|
441 |
+
processors = {}
|
442 |
+
|
443 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
444 |
+
if hasattr(module, "get_processor"):
|
445 |
+
processors[f"{name}.processor"] = module.get_processor()
|
446 |
+
|
447 |
+
for sub_name, child in module.named_children():
|
448 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
449 |
+
|
450 |
+
return processors
|
451 |
+
|
452 |
+
for name, module in self.named_children():
|
453 |
+
fn_recursive_add_processors(name, module, processors)
|
454 |
+
|
455 |
+
return processors
|
456 |
+
|
457 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
458 |
+
r"""
|
459 |
+
Sets the attention processor to use to compute attention.
|
460 |
+
|
461 |
+
Parameters:
|
462 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
463 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
464 |
+
for **all** `Attention` layers.
|
465 |
+
|
466 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
467 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
468 |
+
|
469 |
+
"""
|
470 |
+
count = len(self.attn_processors.keys())
|
471 |
+
|
472 |
+
if isinstance(processor, dict) and len(processor) != count:
|
473 |
+
raise ValueError(
|
474 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
475 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
476 |
+
)
|
477 |
+
|
478 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
479 |
+
if hasattr(module, "set_processor"):
|
480 |
+
if not isinstance(processor, dict):
|
481 |
+
module.set_processor(processor)
|
482 |
+
else:
|
483 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
484 |
+
|
485 |
+
for sub_name, child in module.named_children():
|
486 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
487 |
+
|
488 |
+
for name, module in self.named_children():
|
489 |
+
fn_recursive_attn_processor(name, module, processor)
|
490 |
+
|
491 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
492 |
+
if hasattr(module, "gradient_checkpointing"):
|
493 |
+
module.gradient_checkpointing = value
|
494 |
+
|
495 |
+
def forward(
|
496 |
+
self,
|
497 |
+
hidden_states: torch.Tensor,
|
498 |
+
encoder_hidden_states: torch.Tensor = None,
|
499 |
+
pooled_projections: torch.Tensor = None,
|
500 |
+
timestep: torch.LongTensor = None,
|
501 |
+
img_ids: torch.Tensor = None,
|
502 |
+
txt_ids: torch.Tensor = None,
|
503 |
+
guidance: torch.Tensor = None,
|
504 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
505 |
+
return_dict: bool = True,
|
506 |
+
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
507 |
+
"""
|
508 |
+
The [`FluxTransformer2DModelWithMasking`] forward method.
|
509 |
+
|
510 |
+
Args:
|
511 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
512 |
+
Input `hidden_states`.
|
513 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
514 |
+
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
515 |
+
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
516 |
+
from the embeddings of input conditions.
|
517 |
+
timestep ( `torch.LongTensor`):
|
518 |
+
Used to indicate denoising step.
|
519 |
+
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
520 |
+
A list of tensors that if specified are added to the residuals of transformer blocks.
|
521 |
+
joint_attention_kwargs (`dict`, *optional*):
|
522 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
523 |
+
`self.processor` in
|
524 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
525 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
526 |
+
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
527 |
+
tuple.
|
528 |
+
|
529 |
+
Returns:
|
530 |
+
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
531 |
+
`tuple` where the first element is the sample tensor.
|
532 |
+
"""
|
533 |
+
if joint_attention_kwargs is not None:
|
534 |
+
joint_attention_kwargs = joint_attention_kwargs.copy()
|
535 |
+
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
536 |
+
else:
|
537 |
+
lora_scale = 1.0
|
538 |
+
|
539 |
+
if USE_PEFT_BACKEND:
|
540 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
541 |
+
scale_lora_layers(self, lora_scale)
|
542 |
+
else:
|
543 |
+
if (
|
544 |
+
joint_attention_kwargs is not None
|
545 |
+
and joint_attention_kwargs.get("scale", None) is not None
|
546 |
+
):
|
547 |
+
logger.warning(
|
548 |
+
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
549 |
+
)
|
550 |
+
hidden_states = self.x_embedder(hidden_states)
|
551 |
+
|
552 |
+
timestep = timestep.to(hidden_states.dtype) * 1000
|
553 |
+
if guidance is not None:
|
554 |
+
guidance = guidance.to(hidden_states.dtype) * 1000
|
555 |
+
else:
|
556 |
+
guidance = None
|
557 |
+
temb = (
|
558 |
+
self.time_text_embed(timestep, pooled_projections)
|
559 |
+
if guidance is None
|
560 |
+
else self.time_text_embed(timestep, guidance, pooled_projections)
|
561 |
+
)
|
562 |
+
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
563 |
+
|
564 |
+
if txt_ids.ndim == 3:
|
565 |
+
txt_ids = txt_ids[0]
|
566 |
+
if img_ids.ndim == 3:
|
567 |
+
img_ids = img_ids[0]
|
568 |
+
|
569 |
+
|
570 |
+
# txt_ids = torch.zeros((1024,3)).to(txt_ids.device, dtype=txt_ids.dtype)
|
571 |
+
ids = torch.cat((txt_ids, img_ids), dim=0)
|
572 |
+
|
573 |
+
image_rotary_emb = self.pos_embed(ids)
|
574 |
+
|
575 |
+
for index_block, block in enumerate(self.transformer_blocks):
|
576 |
+
if self.training and self.gradient_checkpointing:
|
577 |
+
|
578 |
+
def create_custom_forward(module, return_dict=None):
|
579 |
+
def custom_forward(*inputs):
|
580 |
+
if return_dict is not None:
|
581 |
+
return module(*inputs, return_dict=return_dict)
|
582 |
+
else:
|
583 |
+
return module(*inputs)
|
584 |
+
|
585 |
+
return custom_forward
|
586 |
+
|
587 |
+
ckpt_kwargs: Dict[str, Any] = (
|
588 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
589 |
+
)
|
590 |
+
encoder_hidden_states, hidden_states = (
|
591 |
+
torch.utils.checkpoint.checkpoint(
|
592 |
+
create_custom_forward(block),
|
593 |
+
hidden_states,
|
594 |
+
encoder_hidden_states,
|
595 |
+
temb,
|
596 |
+
image_rotary_emb,
|
597 |
+
joint_attention_kwargs,
|
598 |
+
**ckpt_kwargs,
|
599 |
+
)
|
600 |
+
)
|
601 |
+
|
602 |
+
else:
|
603 |
+
encoder_hidden_states, hidden_states = block(
|
604 |
+
hidden_states=hidden_states,
|
605 |
+
encoder_hidden_states=encoder_hidden_states,
|
606 |
+
temb=temb,
|
607 |
+
image_rotary_emb=image_rotary_emb,
|
608 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
609 |
+
)
|
610 |
+
|
611 |
+
# Flux places the text tokens in front of the image tokens in the
|
612 |
+
# sequence.
|
613 |
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
614 |
+
|
615 |
+
for index_block, block in enumerate(self.single_transformer_blocks):
|
616 |
+
if self.training and self.gradient_checkpointing:
|
617 |
+
|
618 |
+
def create_custom_forward(module, return_dict=None):
|
619 |
+
def custom_forward(*inputs):
|
620 |
+
if return_dict is not None:
|
621 |
+
return module(*inputs, return_dict=return_dict)
|
622 |
+
else:
|
623 |
+
return module(*inputs)
|
624 |
+
|
625 |
+
return custom_forward
|
626 |
+
|
627 |
+
ckpt_kwargs: Dict[str, Any] = (
|
628 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
629 |
+
)
|
630 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
631 |
+
create_custom_forward(block),
|
632 |
+
hidden_states,
|
633 |
+
temb,
|
634 |
+
image_rotary_emb,
|
635 |
+
joint_attention_kwargs,
|
636 |
+
**ckpt_kwargs,
|
637 |
+
)
|
638 |
+
|
639 |
+
else:
|
640 |
+
hidden_states = block(
|
641 |
+
hidden_states=hidden_states,
|
642 |
+
temb=temb,
|
643 |
+
image_rotary_emb=image_rotary_emb,
|
644 |
+
joint_attention_kwargs=joint_attention_kwargs,
|
645 |
+
)
|
646 |
+
|
647 |
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
648 |
+
|
649 |
+
hidden_states = self.norm_out(hidden_states, temb)
|
650 |
+
output = self.proj_out(hidden_states)
|
651 |
+
|
652 |
+
if USE_PEFT_BACKEND:
|
653 |
+
# remove `lora_scale` from each PEFT layer
|
654 |
+
unscale_lora_layers(self, lora_scale)
|
655 |
+
|
656 |
+
if not return_dict:
|
657 |
+
return (output,)
|
658 |
+
|
659 |
+
return Transformer2DModelOutput(sample=output)
|
660 |
+
|
661 |
+
|
662 |
+
if __name__ == "__main__":
|
663 |
+
dtype = torch.bfloat16
|
664 |
+
bsz = 2
|
665 |
+
img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
|
666 |
+
timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
|
667 |
+
pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
|
668 |
+
text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
|
669 |
+
attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
|
670 |
+
"cuda", dtype=dtype
|
671 |
+
) # Last 128 positions are masked
|
672 |
+
|
673 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
674 |
+
latents = latents.view(
|
675 |
+
batch_size, num_channels_latents, height // 2, 2, width // 2, 2
|
676 |
+
)
|
677 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
678 |
+
latents = latents.reshape(
|
679 |
+
batch_size, (height // 2) * (width // 2), num_channels_latents * 4
|
680 |
+
)
|
681 |
+
|
682 |
+
return latents
|
683 |
+
|
684 |
+
def _prepare_latent_image_ids(
|
685 |
+
batch_size, height, width, device="cuda", dtype=dtype
|
686 |
+
):
|
687 |
+
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
688 |
+
latent_image_ids[..., 1] = (
|
689 |
+
latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
690 |
+
)
|
691 |
+
latent_image_ids[..., 2] = (
|
692 |
+
latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
693 |
+
)
|
694 |
+
|
695 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
|
696 |
+
latent_image_ids.shape
|
697 |
+
)
|
698 |
+
|
699 |
+
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
700 |
+
latent_image_ids = latent_image_ids.reshape(
|
701 |
+
batch_size,
|
702 |
+
latent_image_id_height * latent_image_id_width,
|
703 |
+
latent_image_id_channels,
|
704 |
+
)
|
705 |
+
|
706 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
707 |
+
|
708 |
+
txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)
|
709 |
+
|
710 |
+
vae_scale_factor = 16
|
711 |
+
height = 2 * (int(512) // vae_scale_factor)
|
712 |
+
width = 2 * (int(512) // vae_scale_factor)
|
713 |
+
img_ids = _prepare_latent_image_ids(bsz, height, width)
|
714 |
+
img = _pack_latents(img, img.shape[0], 16, height, width)
|
715 |
+
|
716 |
+
# Gotta go fast
|
717 |
+
transformer = FluxTransformer2DModelWithMasking.from_config(
|
718 |
+
{
|
719 |
+
"attention_head_dim": 128,
|
720 |
+
"guidance_embeds": True,
|
721 |
+
"in_channels": 64,
|
722 |
+
"joint_attention_dim": 4096,
|
723 |
+
"num_attention_heads": 24,
|
724 |
+
"num_layers": 4,
|
725 |
+
"num_single_layers": 8,
|
726 |
+
"patch_size": 1,
|
727 |
+
"pooled_projection_dim": 768,
|
728 |
+
}
|
729 |
+
).to("cuda", dtype=dtype)
|
730 |
+
|
731 |
+
guidance = torch.tensor([2.0], device="cuda")
|
732 |
+
guidance = guidance.expand(bsz)
|
733 |
+
|
734 |
+
with torch.no_grad():
|
735 |
+
no_mask = transformer(
|
736 |
+
img,
|
737 |
+
encoder_hidden_states=text,
|
738 |
+
pooled_projections=pooled,
|
739 |
+
timestep=timestep,
|
740 |
+
img_ids=img_ids,
|
741 |
+
txt_ids=txt_ids,
|
742 |
+
guidance=guidance,
|
743 |
+
)
|
744 |
+
mask = transformer(
|
745 |
+
img,
|
746 |
+
encoder_hidden_states=text,
|
747 |
+
pooled_projections=pooled,
|
748 |
+
timestep=timestep,
|
749 |
+
img_ids=img_ids,
|
750 |
+
txt_ids=txt_ids,
|
751 |
+
guidance=guidance,
|
752 |
+
attention_mask=attn_mask,
|
753 |
+
)
|
754 |
+
|
755 |
+
assert torch.allclose(no_mask.sample, mask.sample) is False
|
756 |
+
print("Attention masking test ran OK. Differences in output were detected.")
|
requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
diffusers
|
2 |
+
torch
|
3 |
+
transformers
|
4 |
+
peft
|
5 |
+
einops
|
6 |
+
numpy
|
7 |
+
Pillow
|
8 |
+
sentencepiece
|
9 |
+
huggingface_hub
|