Nupur Kumari committed · Commit e0f6273 · 0 Parent(s)

Initial commit

.gitattributes ADDED
@@ -0,0 +1,36 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,20 @@
1
+ ---
2
+ title: SynCD
3
+ emoji: 🖼
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.4.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ tags:
12
+ - dwpose
13
+ - pose
14
+ - Text-to-Image
15
+ - Image-to-Image
16
+ - language models
17
+ - LLMs
18
+ short_description: Image generator/identifier/reposer
19
+ ---
20
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,264 @@
1
+ import os
2
+ import random
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import spaces
7
+ import torch
8
+ from einops import rearrange
9
+ from huggingface_hub import login
10
+ from peft import LoraConfig
11
+ from PIL import Image
12
+ from pipelines.flux_pipeline.pipeline import SynCDFluxPipeline
13
+ from pipelines.flux_pipeline.transformer import FluxTransformer2DModelWithMasking
14
+
15
+ HF_TOKEN = os.getenv('HF_TOKEN')
16
+ login(token=HF_TOKEN)
17
+ torch_dtype = torch.bfloat16
18
+ transformer = FluxTransformer2DModelWithMasking.from_pretrained(
19
+ 'black-forest-labs/FLUX.1-dev',
20
+ subfolder='transformer',
21
+ torch_dtype=torch_dtype
22
+ )
23
+ pipeline = SynCDFluxPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', transformer=transformer, torch_dtype=torch_dtype)
24
+ for name, attn_proc in pipeline.transformer.attn_processors.items():
25
+ attn_proc.name = name
26
+
27
+ target_modules=[
28
+ "to_k",
29
+ "to_q",
30
+ "to_v",
31
+ "add_k_proj",
32
+ "add_q_proj",
33
+ "add_v_proj",
34
+ "to_out.0",
35
+ "to_add_out",
36
+ "ff.net.0.proj",
37
+ "ff.net.2",
38
+ "ff_context.net.0.proj",
39
+ "ff_context.net.2",
40
+ "proj_mlp",
41
+ "proj_out",
42
+ ]
43
+ lora_rank = 32
44
+ lora_config = LoraConfig(
45
+ r=lora_rank,
46
+ lora_alpha=lora_rank,
47
+ init_lora_weights="gaussian",
48
+ target_modules=target_modules,
49
+ )
50
+ pipeline.transformer.add_adapter(lora_config)
51
+ finetuned_path = torch.load('models/pytorch_model.bin', map_location='cpu')
52
+ transformer_dict = {}
53
+ for key,value in finetuned_path.items():
54
+ if 'transformer.base_model.model.' in key:
55
+ transformer_dict[key.replace('transformer.base_model.model.', '')] = value
56
+ pipeline.transformer.load_state_dict(transformer_dict, strict=False)
57
+ # pipeline.to('cuda')
58
+ pipeline.enable_vae_slicing()
59
+ pipeline.enable_vae_tiling()
60
+
61
+ @torch.no_grad()
62
+ def decode(latents, pipeline):
63
+ latents = latents / pipeline.vae.config.scaling_factor
64
+ image = pipeline.vae.decode(latents, return_dict=False)[0]
65
+ return image
66
+
67
+
68
+ @torch.no_grad()
69
+ def encode_target_images(images, pipeline):
70
+ latents = pipeline.vae.encode(images).latent_dist.sample()
71
+ latents = latents * pipeline.vae.config.scaling_factor
72
+ return latents
73
+
74
+
75
+ @spaces.GPU(duration=120)
76
+ def generate_image(text, img1, img2, img3, guidance_scale, inference_steps, seed, rigid_object, enable_cpu_offload=False):
77
+ if enable_cpu_offload:
78
+ pipeline.enable_sequential_cpu_offload()
79
+ input_images = [img1, img2, img3]
80
+ # Delete None
81
+ input_images = [img for img in input_images if img is not None]
82
+ if len(input_images) == 0:
83
+ return "Please upload at least one image"
84
+ numref = len(input_images) + 1
85
+ images = torch.cat([2. * torch.from_numpy(np.array(Image.open(img).convert('RGB').resize((512, 512)))).permute(2, 0, 1).unsqueeze(0).to(torch_dtype)/255. -1. for img in input_images])
86
+ images = images.to(pipeline.device)
87
+ latents = encode_target_images(images, pipeline)
88
+ latents = torch.cat([torch.zeros_like(latents[:1]), latents], dim=0)
89
+ masklatent = torch.zeros_like(latents)
90
+ masklatent[:1] = 1.
91
+ latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
92
+ masklatent = rearrange(masklatent, "(b n) c h w -> b c h (n w)", n=numref)
93
+ B, C, H, W = latents.shape
94
+ latents = pipeline._pack_latents(latents, B, C, H, W)
95
+ masklatent = pipeline._pack_latents(masklatent.expand(-1, C, -1, -1), B, C, H, W)
96
+ output = pipeline(
97
+ text,
98
+ latents_ref=latents,
99
+ latents_mask=masklatent,
100
+ guidance_scale=guidance_scale,
101
+ num_inference_steps=inference_steps,
102
+ height=512,
103
+ width=numref * 512,
104
+ generator = torch.Generator(device="cpu").manual_seed(seed),
105
+ joint_attention_kwargs={'shared_attn': True, 'num': numref},
106
+ return_dict=False,
107
+ )[0][0]
108
+ output = rearrange(output, "b c h (n w) -> (b n) c h w", n=numref)[::numref]
109
+ img = Image.fromarray(((torch.clip(output[0].float(), -1., 1.).permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8))
110
+ return img
111
+
112
+
113
+
114
+ def get_example():
115
+ case = [
116
+ [
117
+ "A toy on a beach. Waves in the background. Realistic shot.",
118
+ "./imgs/test_cases/rc_car/02.jpg",
119
+ "./imgs/test_cases/rc_car/03.jpg",
120
+ "./imgs/test_cases/rc_car/04.jpg",
121
+ 3.5,
122
+ 42,
123
+ True,
124
+ ],
125
+ [
126
+ "An action figure on top of a mountain. Sunset in the background. Realistic shot.",
127
+ "./imgs/test_cases/action_figure/0.jpg",
128
+ "./imgs/test_cases/action_figure/1.jpg",
129
+ "./imgs/test_cases/action_figure/2.jpg",
130
+ 3.5,
131
+ 42,
132
+ True,
133
+ ],
134
+ [
135
+ "A penguin plushing wearing pink sunglasses is lounging on a beach. Realistic shot.",
136
+ "./imgs/test_cases/penguin/0.jpg",
137
+ "./imgs/test_cases/penguin/1.jpg",
138
+ "./imgs/test_cases/penguin/2.jpg",
139
+ 3.5,
140
+ 42,
141
+ True,
142
+ ],
143
+ ]
144
+ return case
145
+
146
+ def run_for_examples(text, img1, img2, img3, guidance_scale, seed, rigid_object, enable_cpu_offload=False):
147
+ inference_steps = 30
148
+
149
+ return generate_image(
150
+ text, img1, img2, img3, guidance_scale, inference_steps, seed, rigid_object, enable_cpu_offload
151
+ )
152
+
153
+ description = """
154
+ Synthetic Customization Dataset (SynCD) consists of multiple images of the same object in different contexts. We achieve this by promoting a consistent object identity, using either explicit 3D object assets or, more implicitly, masked shared attention across different views during generation. Given this training data, we train a new encoder-based model for the task, which can successfully generate new compositions of a reference object from text prompts. You can download our dataset [here](https://huggingface.co/datasets/nupurkmr9/syncd).
155
+
156
+ Our model supports multiple input images of the same object as references. You can upload up to 3 images; results are typically better with 3 reference images than with a single one.
157
+
158
+ **HF Spaces often encounter errors due to quota limitations, so we recommend running the demo locally.**
159
+ """
160
+
161
+ article = """
162
+ ---
163
+ **Citation**
164
+ <br>
165
+ If you find this repository useful, please consider giving it a star ⭐ and a citation:
166
+ ```
167
+ @article{kumari2025syncd,
168
+ title={Generating Multi-Image Synthetic Data for Text-to-Image Customization},
169
+ author={Kumari, Nupur and Yin, Xi and Zhu, Jun-Yan and Misra, Ishan and Azadi, Samaneh},
170
+ journal={ArXiv},
171
+ year={2025}
172
+ }
173
+ ```
174
+ **Contact**
175
+ <br>
176
+ If you have any questions, please feel free to open an issue or reach out to us directly via email.
177
+
178
+ **Acknowledgement**
179
+ <br>
180
+ This Space was adapted from the [OmniGen](https://huggingface.co/spaces/Shitao/OmniGen) Space.
181
+ """
182
+
183
+
184
+ # Gradio
185
+ with gr.Blocks() as demo:
186
+ gr.Markdown("# SynCD: Generating Multi-Image Synthetic Data for Text-to-Image Customization [[paper](https://arxiv.org/abs/2502.01720)] [[code](https://github.com/nupurkmr9/syncd)]")
187
+ gr.Markdown(description)
188
+ with gr.Row():
189
+ with gr.Column():
190
+ # text prompt
191
+ prompt_input = gr.Textbox(
192
+ label="Enter your prompt, more descriptive prompt will lead to better results", placeholder="Type your prompt here..."
193
+ )
194
+
195
+ with gr.Row(equal_height=True):
196
+ # input images
197
+ image_input_1 = gr.Image(label="img1", type="filepath")
198
+ image_input_2 = gr.Image(label="img2", type="filepath")
199
+ image_input_3 = gr.Image(label="img3", type="filepath")
200
+
201
+ guidance_scale_input = gr.Slider(
202
+ label="Guidance Scale", minimum=1.0, maximum=5.0, value=3.5, step=0.1
203
+ )
204
+
205
+ num_inference_steps = gr.Slider(
206
+ label="Inference Steps", minimum=1, maximum=100, value=30, step=1
207
+ )
208
+
209
+ seed_input = gr.Slider(
210
+ label="Seed", minimum=0, maximum=2147483647, value=42, step=1
211
+ )
212
+
213
+ rigid_object = gr.Checkbox(
214
+ label="rigid_object", info="Whether its a rigid object or a deformable object like pet animals, wearable etc.", value=True,
215
+ )
216
+ enable_cpu_offload = gr.Checkbox(
217
+ label="Enable CPU Offload", info="Enable CPU Offload to avoid memory issues", value=False,
218
+ )
219
+
220
+ # generate
221
+ generate_button = gr.Button("Generate Image")
222
+
223
+
224
+ with gr.Column():
225
+ # output image
226
+ output_image = gr.Image(label="Output Image")
227
+
228
+ # click
229
+ generate_button.click(
230
+ generate_image,
231
+ inputs=[
232
+ prompt_input,
233
+ image_input_1,
234
+ image_input_2,
235
+ image_input_3,
236
+ guidance_scale_input,
237
+ num_inference_steps,
238
+ seed_input,
239
+ rigid_object,
240
+ enable_cpu_offload,
241
+ ],
242
+ outputs=output_image,
243
+ )
244
+
245
+ gr.Examples(
246
+ examples=get_example(),
247
+ fn=run_for_examples,
248
+ inputs=[
249
+ prompt_input,
250
+ image_input_1,
251
+ image_input_2,
252
+ image_input_3,
253
+ guidance_scale_input,
254
+ seed_input,
255
+ rigid_object,
256
+ ],
257
+ outputs=output_image,
258
+ )
259
+
260
+ gr.Markdown(article)
261
+
262
+ # launch
263
+ demo.launch()
264
+
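The key data-layout step in `generate_image` above is worth isolating: the encoded reference latents are tiled side by side along the width next to an all-zero "target" slot, and a binary mask marks which slot should actually be denoised. Below is a minimal, model-free sketch of that step using dummy tensors; the 16-channel, 64×64 latent shape is an assumption matching roughly what the FLUX VAE produces for the 512×512 inputs used in `app.py`.

```python
import torch
from einops import rearrange

# Dummy stand-ins for 3 encoded reference images (16 latent channels, 64x64 spatial,
# roughly what a 512x512 image gives after the FLUX VAE's 8x downsampling).
ref_latents = torch.randn(3, 16, 64, 64)
numref = ref_latents.shape[0] + 1            # 3 references + 1 target slot

# Prepend an all-zero latent for the image to be generated, and a mask that is 1
# only on that target slot (this mirrors `generate_image` in app.py).
latents = torch.cat([torch.zeros_like(ref_latents[:1]), ref_latents], dim=0)
masklatent = torch.zeros_like(latents)
masklatent[:1] = 1.

# Tile the views along the width: (b*n, c, h, w) -> (b, c, h, n*w).
latents = rearrange(latents, "(b n) c h w -> b c h (n w)", n=numref)
masklatent = rearrange(masklatent, "(b n) c h w -> b c h (n w)", n=numref)
print(latents.shape, masklatent.shape)       # both torch.Size([1, 16, 64, 256])
```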
imgs/test_cases/action_figure/0.jpg ADDED
imgs/test_cases/action_figure/1.jpg ADDED
imgs/test_cases/action_figure/2.jpg ADDED
imgs/test_cases/penguin/0.jpg ADDED
imgs/test_cases/penguin/1.jpg ADDED
imgs/test_cases/penguin/2.jpg ADDED
imgs/test_cases/rc_car/02.jpg ADDED
imgs/test_cases/rc_car/03.jpg ADDED
imgs/test_cases/rc_car/04.jpg ADDED
models/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1947b2008809a98ef1d77b0c98365ccfb9f8b4285873ab3b26cfe43b58a2f4c6
3
+ size 358868218
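The three lines above are a Git LFS pointer rather than the checkpoint itself; the actual ~359 MB weights live in LFS storage and are resolved on clone or `git lfs pull`. As a small sketch (an assumption, not part of the repo), the recorded `oid` and `size` can be used to verify a downloaded copy:

```python
import hashlib
import os

path = "models/pytorch_model.bin"   # assumed local path after `git lfs pull`
expected_oid = "1947b2008809a98ef1d77b0c98365ccfb9f8b4285873ab3b26cfe43b58a2f4c6"
expected_size = 358868218

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):   # hash in 1 MiB chunks
        sha256.update(chunk)

assert os.path.getsize(path) == expected_size, "size mismatch: file may still be an LFS pointer"
assert sha256.hexdigest() == expected_oid, "sha256 mismatch"
print("checkpoint matches the LFS pointer")
```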
pipelines/flux_pipeline/pipeline.py ADDED
@@ -0,0 +1,443 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers import FluxPipeline
30
+ from diffusers.image_processor import VaeImageProcessor
31
+ from diffusers.loaders import FluxLoraLoaderMixin
32
+ from diffusers.models.autoencoders import AutoencoderKL
33
+ from diffusers.models.transformers import FluxTransformer2DModel
34
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
35
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_xla_available
36
+
37
+ if is_torch_xla_available():
38
+ import torch_xla.core.xla_model as xm
39
+
40
+ XLA_AVAILABLE = True
41
+ else:
42
+ XLA_AVAILABLE = False
43
+
44
+
45
+
46
+ def calculate_shift(
47
+ image_seq_len,
48
+ base_seq_len: int = 256,
49
+ max_seq_len: int = 4096,
50
+ base_shift: float = 0.5,
51
+ max_shift: float = 1.16,
52
+ ):
53
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
54
+ b = base_shift - m * base_seq_len
55
+ mu = image_seq_len * m + b
56
+ return mu
57
+
58
+
59
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
60
+ def retrieve_timesteps(
61
+ scheduler,
62
+ num_inference_steps: Optional[int] = None,
63
+ device: Optional[Union[str, torch.device]] = None,
64
+ timesteps: Optional[List[int]] = None,
65
+ sigmas: Optional[List[float]] = None,
66
+ **kwargs,):
67
+ if timesteps is not None and sigmas is not None:
68
+ raise ValueError(
69
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
70
+ )
71
+ if timesteps is not None:
72
+ accepts_timesteps = "timesteps" in set(
73
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
74
+ )
75
+ if not accepts_timesteps:
76
+ raise ValueError(
77
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
78
+ f" timestep schedules. Please check whether you are using the correct scheduler."
79
+ )
80
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
81
+ timesteps = scheduler.timesteps
82
+ num_inference_steps = len(timesteps)
83
+ elif sigmas is not None:
84
+ accept_sigmas = "sigmas" in set(
85
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
86
+ )
87
+ if not accept_sigmas:
88
+ raise ValueError(
89
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
90
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
91
+ )
92
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
93
+ timesteps = scheduler.timesteps
94
+ num_inference_steps = len(timesteps)
95
+ else:
96
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
97
+ timesteps = scheduler.timesteps
98
+ return timesteps, num_inference_steps
99
+
100
+
101
+ class SynCDFluxPipeline(FluxPipeline):
102
+
103
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
104
+ _optional_components = []
105
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
106
+
107
+ def __init__(
108
+ self,
109
+ scheduler: FlowMatchEulerDiscreteScheduler,
110
+ vae: AutoencoderKL,
111
+ text_encoder: CLIPTextModel,
112
+ tokenizer: CLIPTokenizer,
113
+ text_encoder_2: T5EncoderModel,
114
+ tokenizer_2: T5TokenizerFast,
115
+ transformer: FluxTransformer2DModel,
116
+ image_encoder: CLIPVisionModelWithProjection = None,
117
+ feature_extractor: CLIPImageProcessor = None,
118
+ ###
119
+ num=2,
120
+ ):
121
+ super().__init__(
122
+ vae=vae,
123
+ text_encoder=text_encoder,
124
+ text_encoder_2=text_encoder_2,
125
+ tokenizer=tokenizer,
126
+ tokenizer_2=tokenizer_2,
127
+ transformer=transformer,
128
+ scheduler=scheduler,
129
+ image_encoder=image_encoder,
130
+ feature_extractor=feature_extractor
131
+ )
132
+ self.default_sample_size = 64
133
+ self.num = num
134
+
135
+ @torch.no_grad()
136
+ def __call__(
137
+ self,
138
+ prompt: Union[str, List[str]] = None,
139
+ prompt_2: Optional[Union[str, List[str]]] = None,
140
+ negative_prompt: Union[str, List[str]] = None,
141
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
142
+ true_cfg_scale: float = 1.0,
143
+ height: Optional[int] = None,
144
+ width: Optional[int] = None,
145
+ num_inference_steps: int = 28,
146
+ sigmas: Optional[List[float]] = None,
147
+ guidance_scale: float = 3.5,
148
+ num_images_per_prompt: Optional[int] = 1,
149
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
150
+ latents: Optional[torch.FloatTensor] = None,
151
+ prompt_embeds: Optional[torch.FloatTensor] = None,
152
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
153
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
154
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
155
+ output_type: Optional[str] = "pil",
156
+ return_dict: bool = True,
157
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
158
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
159
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
160
+ max_sequence_length: int = 512,
161
+ #####
162
+ latents_ref: Optional[torch.Tensor] = None,
163
+ latents_mask: Optional[torch.Tensor] = None,
164
+ return_latents: bool=False,
165
+ ):
166
+ r"""
167
+ Function invoked when calling the pipeline for generation.
168
+
169
+ Args:
170
+ prompt (`str` or `List[str]`, *optional*):
171
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
172
+ instead.
173
+ prompt_2 (`str` or `List[str]`, *optional*):
174
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
175
+ will be used instead.
176
+ negative_prompt (`str` or `List[str]`, *optional*):
177
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
178
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
179
+ not greater than `1`).
180
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
181
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
182
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
183
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
184
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
185
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
186
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
187
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
188
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
189
+ num_inference_steps (`int`, *optional*, defaults to 28):
190
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
191
+ expense of slower inference.
192
+ sigmas (`List[float]`, *optional*):
193
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
194
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
195
+ will be used.
196
+ guidance_scale (`float`, *optional*, defaults to 3.5):
197
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
198
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
199
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
200
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
201
+ usually at the expense of lower image quality.
202
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
203
+ The number of images to generate per prompt.
204
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
205
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
206
+ to make generation deterministic.
207
+ latents (`torch.FloatTensor`, *optional*):
208
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
209
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
210
+ tensor will ge generated by sampling using the supplied random `generator`.
211
+ prompt_embeds (`torch.FloatTensor`, *optional*):
212
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
213
+ provided, text embeddings will be generated from `prompt` input argument.
214
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
215
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
216
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
217
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
218
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
219
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
220
+ argument.
221
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
222
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
223
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
224
+ input argument.
225
+ output_type (`str`, *optional*, defaults to `"pil"`):
226
+ The output format of the generate image. Choose between
227
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
228
+ return_dict (`bool`, *optional*, defaults to `True`):
229
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
230
+ joint_attention_kwargs (`dict`, *optional*):
231
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
232
+ `self.processor` in
233
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
234
+ callback_on_step_end (`Callable`, *optional*):
235
+ A function that calls at the end of each denoising steps during the inference. The function is called
236
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
237
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
238
+ `callback_on_step_end_tensor_inputs`.
239
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
240
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
241
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
242
+ `._callback_tensor_inputs` attribute of your pipeline class.
243
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
244
+
245
+ Examples:
246
+
247
+ Returns:
248
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
249
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
250
+ images.
251
+ """
252
+
253
+ height = height or self.default_sample_size * self.vae_scale_factor
254
+ width = width or self.default_sample_size * self.vae_scale_factor
255
+
256
+ # 1. Check inputs. Raise error if not correct
257
+ self.check_inputs(
258
+ prompt,
259
+ prompt_2,
260
+ height,
261
+ width,
262
+ negative_prompt=negative_prompt,
263
+ negative_prompt_2=negative_prompt_2,
264
+ prompt_embeds=prompt_embeds,
265
+ negative_prompt_embeds=negative_prompt_embeds,
266
+ pooled_prompt_embeds=pooled_prompt_embeds,
267
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
268
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
269
+ max_sequence_length=max_sequence_length,
270
+ )
271
+
272
+ self._guidance_scale = guidance_scale
273
+ self._joint_attention_kwargs = joint_attention_kwargs
274
+ self._current_timestep = None
275
+ self._interrupt = False
276
+
277
+ # 2. Define call parameters
278
+ if prompt is not None and isinstance(prompt, str):
279
+ batch_size = 1
280
+ elif prompt is not None and isinstance(prompt, list):
281
+ batch_size = len(prompt)
282
+ else:
283
+ batch_size = prompt_embeds.shape[0]
284
+
285
+ device = self._execution_device
286
+
287
+ lora_scale = (
288
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
289
+ )
290
+ has_neg_prompt = negative_prompt is not None or (
291
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
292
+ )
293
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
294
+ (
295
+ prompt_embeds,
296
+ pooled_prompt_embeds,
297
+ text_ids,
298
+ ) = self.encode_prompt(
299
+ prompt=prompt,
300
+ prompt_2=prompt_2,
301
+ prompt_embeds=prompt_embeds,
302
+ pooled_prompt_embeds=pooled_prompt_embeds,
303
+ device=device,
304
+ num_images_per_prompt=num_images_per_prompt,
305
+ max_sequence_length=max_sequence_length,
306
+ lora_scale=lora_scale,
307
+ )
308
+ if do_true_cfg:
309
+ (
310
+ negative_prompt_embeds,
311
+ negative_pooled_prompt_embeds,
312
+ _,
313
+ ) = self.encode_prompt(
314
+ prompt=negative_prompt,
315
+ prompt_2=negative_prompt_2,
316
+ prompt_embeds=negative_prompt_embeds,
317
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
318
+ device=device,
319
+ num_images_per_prompt=num_images_per_prompt,
320
+ max_sequence_length=max_sequence_length,
321
+ lora_scale=lora_scale,
322
+ )
323
+
324
+ # 4. Prepare latent variables
325
+ num_channels_latents = self.transformer.config.in_channels // 4
326
+ latents, latent_image_ids = self.prepare_latents(
327
+ batch_size * num_images_per_prompt,
328
+ num_channels_latents,
329
+ height,
330
+ width,
331
+ prompt_embeds.dtype,
332
+ device,
333
+ generator,
334
+ latents,
335
+ )
336
+
337
+ # 5. Prepare timesteps
338
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
339
+ image_seq_len = latents.shape[1]
340
+ mu = calculate_shift(
341
+ image_seq_len,
342
+ self.scheduler.config.get("base_image_seq_len", 256),
343
+ self.scheduler.config.get("max_image_seq_len", 4096),
344
+ self.scheduler.config.get("base_shift", 0.5),
345
+ self.scheduler.config.get("max_shift", 1.15),
346
+ )
347
+ timesteps, num_inference_steps = retrieve_timesteps(
348
+ self.scheduler,
349
+ num_inference_steps,
350
+ device,
351
+ sigmas=sigmas,
352
+ mu=mu,
353
+ )
354
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
355
+ self._num_timesteps = len(timesteps)
356
+
357
+ # handle guidance
358
+ if self.transformer.config.guidance_embeds:
359
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
360
+ guidance = guidance.expand(latents.shape[0])
361
+ else:
362
+ guidance = None
363
+
364
+ if self.joint_attention_kwargs is None:
365
+ self._joint_attention_kwargs = {}
366
+
367
+ # 6. Denoising loop
368
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
369
+ for i, t in enumerate(timesteps):
370
+ if self.interrupt:
371
+ continue
372
+
373
+ self._current_timestep = t
374
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
375
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
376
+ self.joint_attention_kwargs.update({'timestep': t/1000, 'val': True})
377
+ if self.joint_attention_kwargs is not None and self.joint_attention_kwargs['shared_attn'] and latents_ref is not None and latents_mask is not None:
378
+ latents = (1 - latents_mask) * latents_ref + latents_mask * latents
379
+
380
+ noise_pred = self.transformer(
381
+ hidden_states=latents,
382
+ timestep=timestep / 1000,
383
+ guidance=guidance,
384
+ pooled_projections=pooled_prompt_embeds,
385
+ encoder_hidden_states=prompt_embeds,
386
+ txt_ids=text_ids,
387
+ img_ids=latent_image_ids,
388
+ joint_attention_kwargs=self.joint_attention_kwargs,
389
+ return_dict=False,
390
+ )[0]
391
+
392
+ if do_true_cfg:
393
+ neg_noise_pred = self.transformer(
394
+ hidden_states=latents,
395
+ timestep=timestep / 1000,
396
+ guidance=guidance,
397
+ pooled_projections=negative_pooled_prompt_embeds,
398
+ encoder_hidden_states=negative_prompt_embeds,
399
+ txt_ids=text_ids,
400
+ img_ids=latent_image_ids,
401
+ joint_attention_kwargs=self.joint_attention_kwargs,
402
+ return_dict=False,
403
+ )[0]
404
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
405
+
406
+ # compute the previous noisy sample x_t -> x_t-1
407
+ latents_dtype = latents.dtype
408
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
409
+
410
+ if latents.dtype != latents_dtype:
411
+ if torch.backends.mps.is_available():
412
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
413
+ latents = latents.to(latents_dtype)
414
+
415
+ if callback_on_step_end is not None:
416
+ callback_kwargs = {}
417
+ for k in callback_on_step_end_tensor_inputs:
418
+ callback_kwargs[k] = locals()[k]
419
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
420
+
421
+ latents = callback_outputs.pop("latents", latents)
422
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
423
+
424
+ # call the callback, if provided
425
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
426
+ progress_bar.update()
427
+
428
+ if XLA_AVAILABLE:
429
+ xm.mark_step()
430
+
431
+ self._current_timestep = None
432
+
433
+ if output_type == "latent":
434
+ image = latents
435
+ else:
436
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
437
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
438
+ image = self.vae.decode(latents, return_dict=False)
439
+
440
+ # Offload all models
441
+ self.maybe_free_model_hooks()
442
+
443
+ return (image,)
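The SynCD-specific part of the denoising loop above is the re-compositing at every step: `latents = (1 - latents_mask) * latents_ref + latents_mask * latents` keeps the reference slots pinned to their encoded values while only the masked target slot is denoised, with `shared_attn` letting the target attend to the references. A hedged usage sketch, assuming `pipeline`, the packed `latents`/`masklatent`, and `numref` were prepared exactly as in `app.py`:

```python
import torch

# Sketch only: mirrors the call made by `generate_image` in app.py.
out = pipeline(
    "A toy on a beach. Waves in the background. Realistic shot.",
    latents_ref=latents,        # packed reference latents (zeros in the target slot)
    latents_mask=masklatent,    # 1 on the target slot, 0 on the reference slots
    guidance_scale=3.5,
    num_inference_steps=30,
    height=512,
    width=numref * 512,         # all views are tiled along the width
    generator=torch.Generator(device="cpu").manual_seed(42),
    joint_attention_kwargs={"shared_attn": True, "num": numref},
    return_dict=False,
)[0][0]                         # decoded (1, 3, 512, numref*512) tensor; app.py then splits
                                # the width back into views and keeps the first (target) view
```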
pipelines/flux_pipeline/transformer.py ADDED
@@ -0,0 +1,756 @@
1
+ # https://github.com/bghira/SimpleTuner/blob/d0b5f37913a80aabdb0cac893937072dfa3e6a4b/helpers/models/flux/transformer.py#L404
2
+ # Copyright 2024 Stability AI, The HuggingFace Team, The InstantX Team, and Terminus Research Group. All rights reserved.
3
+ #
4
+ # Originally licensed under the Apache License, Version 2.0 (the "License");
5
+ # Updated to "Affero GENERAL PUBLIC LICENSE Version 3, 19 November 2007" via extensive updates to attn_mask usage.
6
+
7
+ import math
8
+ from contextlib import contextmanager
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange
15
+ from peft.tuners.lora.layer import LoraLayer
16
+
17
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
18
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
19
+ from diffusers.models.attention import FeedForward
20
+ from diffusers.models.attention_processor import Attention, AttentionProcessor
21
+ from diffusers.models.embeddings import (
22
+ CombinedTimestepGuidanceTextProjEmbeddings,
23
+ CombinedTimestepTextProjEmbeddings,
24
+ FluxPosEmbed,
25
+ )
26
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.models.normalization import (
29
+ AdaLayerNormContinuous,
30
+ AdaLayerNormZero,
31
+ AdaLayerNormZeroSingle,
32
+ )
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ is_torch_version,
36
+ logging,
37
+ scale_lora_layers,
38
+ unscale_lora_layers,
39
+ )
40
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
41
+
42
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
+
44
+ def log_scale_masking(value, min_value=1, max_value=10):
45
+ # Convert the value into a positive domain for the logarithmic function
46
+ normalized_value = 1*value
47
+
48
+ # Apply logarithmic scaling
49
+ # log_scaled_value = 1-np.exp(-normalized_value)
50
+ log_scaled_value = 2.0 * math.log(normalized_value + 1, 2) / math.log(2, 2)  # equals 2 * log2(value + 1) since math.log(2, 2) == 1
51
+ # print(log_scaled_value)
52
+
53
+ # Rescale to original range
54
+ scaled_value = log_scaled_value * (max_value - min_value) + min_value
55
+
56
+ return min(max_value, int(scaled_value))
57
+
58
+ class FluxAttnProcessor2_0:
59
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
60
+
61
+ def __init__(self):
62
+ if not hasattr(F, "scaled_dot_product_attention"):
63
+ raise ImportError(
64
+ "FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
65
+ )
66
+ self.name = None
67
+
68
+ def __call__(
69
+ self,
70
+ attn: Attention,
71
+ hidden_states: torch.FloatTensor,
72
+ encoder_hidden_states: torch.FloatTensor = None,
73
+ attention_mask: Optional[torch.FloatTensor] = None,
74
+ image_rotary_emb: Optional[torch.Tensor] = None,
75
+ shared_attn: bool=False, num=2,
76
+ mode="a",
77
+ ref_dict: dict = None,
78
+ single: bool=False,
79
+ scale: float = 1.0,
80
+ timestep: float = 0,
81
+ val: bool = False,
82
+ ) -> torch.FloatTensor:
83
+ if mode == 'w': # and single:
84
+ ref_dict[self.name] = hidden_states.detach()
85
+
86
+ batch_size, _, _ = (
87
+ hidden_states.shape
88
+ if encoder_hidden_states is None
89
+ else encoder_hidden_states.shape
90
+ )
91
+ end_of_hidden_states = hidden_states.shape[1]
92
+ text_seq = 512
93
+ mask = None
94
+ query = attn.to_q(hidden_states)
95
+ key = attn.to_k(hidden_states)
96
+ value = attn.to_v(hidden_states)
97
+
98
+ inner_dim = key.shape[-1]
99
+ head_dim = inner_dim // attn.heads
100
+
101
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
102
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
103
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
104
+
105
+ if attn.norm_q is not None:
106
+ query = attn.norm_q(query)
107
+ if attn.norm_k is not None:
108
+ key = attn.norm_k(key)
109
+
110
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
111
+ if encoder_hidden_states is not None:
112
+ # `context` projections.
113
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
114
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
115
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
116
+
117
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
118
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
119
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
120
+
121
+ if attn.norm_added_q is not None:
122
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
123
+ if attn.norm_added_k is not None:
124
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
125
+
126
+ # attention
127
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
128
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
129
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
130
+
131
+ if image_rotary_emb is not None:
132
+ from diffusers.models.embeddings import apply_rotary_emb
133
+ query = apply_rotary_emb(query, image_rotary_emb)
134
+ key = apply_rotary_emb(key, image_rotary_emb)
135
+
136
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask if timestep < 1. else None)
137
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
138
+
139
+ hidden_states = hidden_states.to(query.dtype)
140
+
141
+ if encoder_hidden_states is not None:
142
+ encoder_hidden_states, hidden_states = (
143
+ hidden_states[:, : encoder_hidden_states.shape[1]],
144
+ hidden_states[:, encoder_hidden_states.shape[1] : ],
145
+ )
146
+ hidden_states = hidden_states[:, :end_of_hidden_states]
147
+
148
+ # linear proj
149
+ hidden_states = attn.to_out[0](hidden_states)
150
+ # dropout
151
+ hidden_states = attn.to_out[1](hidden_states)
152
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
153
+ return hidden_states, encoder_hidden_states
154
+ else:
155
+ return hidden_states[:, :end_of_hidden_states]
156
+
157
+
158
+ def expand_flux_attention_mask(
159
+ hidden_states: torch.Tensor,
160
+ attn_mask: torch.Tensor,
161
+ ) -> torch.Tensor:
162
+ """
163
+ Expand a mask so that the image is included.
164
+ """
165
+ bsz = attn_mask.shape[0]
166
+ assert bsz == hidden_states.shape[0]
167
+ residual_seq_len = hidden_states.shape[1]
168
+ mask_seq_len = attn_mask.shape[1]
169
+
170
+ expanded_mask = torch.ones(bsz, residual_seq_len)
171
+ expanded_mask[:, :mask_seq_len] = attn_mask
172
+
173
+ return expanded_mask
174
+
175
+
176
+ @maybe_allow_in_graph
177
+ class FluxSingleTransformerBlock(nn.Module):
178
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
179
+ super().__init__()
180
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
181
+
182
+ self.norm = AdaLayerNormZeroSingle(dim)
183
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
184
+ self.act_mlp = nn.GELU(approximate="tanh")
185
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
186
+
187
+ processor = FluxAttnProcessor2_0()
188
+ # processor = FluxSingleAttnProcessor3_0()
189
+
190
+ self.attn = Attention(
191
+ query_dim=dim,
192
+ cross_attention_dim=None,
193
+ dim_head=attention_head_dim,
194
+ heads=num_attention_heads,
195
+ out_dim=dim,
196
+ bias=True,
197
+ processor=processor,
198
+ qk_norm="rms_norm",
199
+ eps=1e-6,
200
+ pre_only=True,
201
+ )
202
+
203
+ def forward(
204
+ self,
205
+ hidden_states: torch.FloatTensor,
206
+ temb: torch.FloatTensor,
207
+ image_rotary_emb=None,
208
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
209
+ ):
210
+ residual = hidden_states
211
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
212
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
213
+
214
+ attn_output = self.attn(
215
+ hidden_states=norm_hidden_states,
216
+ image_rotary_emb=image_rotary_emb,
217
+ **joint_attention_kwargs,
218
+ single=True,
219
+ )
220
+
221
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
222
+ gate = gate.unsqueeze(1)
223
+ hidden_states = gate * self.proj_out(hidden_states)
224
+ hidden_states = residual + hidden_states
225
+
226
+ return hidden_states
227
+
228
+
229
+ @maybe_allow_in_graph
230
+ class FluxTransformerBlock(nn.Module):
231
+ def __init__(
232
+ self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6
233
+ ):
234
+ super().__init__()
235
+
236
+ self.norm1 = AdaLayerNormZero(dim)
237
+
238
+ self.norm1_context = AdaLayerNormZero(dim)
239
+
240
+ if hasattr(F, "scaled_dot_product_attention"):
241
+ processor = FluxAttnProcessor2_0()
242
+ else:
243
+ raise ValueError(
244
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
245
+ )
246
+ self.attn = Attention(
247
+ query_dim=dim,
248
+ cross_attention_dim=None,
249
+ added_kv_proj_dim=dim,
250
+ dim_head=attention_head_dim,
251
+ heads=num_attention_heads,
252
+ out_dim=dim,
253
+ context_pre_only=False,
254
+ bias=True,
255
+ processor=processor,
256
+ qk_norm=qk_norm,
257
+ eps=eps,
258
+ )
259
+
260
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
261
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
262
+
263
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
264
+ self.ff_context = FeedForward(
265
+ dim=dim, dim_out=dim, activation_fn="gelu-approximate"
266
+ )
267
+
268
+ # let chunk size default to None
269
+ self._chunk_size = None
270
+ self._chunk_dim = 0
271
+
272
+ def forward(
273
+ self,
274
+ hidden_states: torch.FloatTensor,
275
+ encoder_hidden_states: torch.FloatTensor,
276
+ temb: torch.FloatTensor,
277
+ image_rotary_emb=None,
278
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None
279
+ ):
280
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
281
+
282
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = (self.norm1_context(encoder_hidden_states, emb=temb))
283
+
284
+ # Attention.
285
+ attn_output, context_attn_output = self.attn(
286
+ hidden_states=norm_hidden_states,
287
+ encoder_hidden_states=norm_encoder_hidden_states,
288
+ image_rotary_emb=image_rotary_emb,
289
+ **joint_attention_kwargs,
290
+ single=False,
291
+ )
292
+
293
+ # Process attention outputs for the `hidden_states`.
294
+ attn_output = gate_msa.unsqueeze(1) * attn_output
295
+ hidden_states = hidden_states + attn_output
296
+
297
+ norm_hidden_states = self.norm2(hidden_states)
298
+ norm_hidden_states = (norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None])
299
+
300
+ ff_output = self.ff(norm_hidden_states)
301
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
302
+
303
+ hidden_states = hidden_states + ff_output
304
+
305
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
306
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
307
+
308
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
309
+ norm_encoder_hidden_states = (
310
+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
311
+ + c_shift_mlp[:, None]
312
+ )
313
+
314
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
315
+ encoder_hidden_states = (
316
+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
317
+ )
318
+
319
+ return encoder_hidden_states, hidden_states
320
+
321
+
322
+ @contextmanager
323
+ def set_adapter_scale(model, alpha):
324
+ original_scaling = {}
325
+ for module in model.modules():
326
+ if isinstance(module, LoraLayer):
327
+ original_scaling[module] = module.scaling.copy()
328
+ module.scaling = {k: v * alpha for k, v in module.scaling.items()}
329
+
330
+ # check whether scaling is prohibited on model
331
+ # the original scaling dictionary should be empty
332
+ # if there were no lora layers
333
+ if not original_scaling:
334
+ raise ValueError("scaling is only supported for models with `LoraLayer`s")
335
+ try:
336
+ yield
337
+
338
+ finally:
339
+ # restore original scaling values after exiting the context
340
+ for module, scaling in original_scaling.items():
341
+ module.scaling = scaling
342
+
343
+ class FluxTransformer2DModelWithMasking(
344
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin
345
+ ):
346
+ """
347
+ The Transformer model introduced in Flux.
348
+
349
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
350
+
351
+ Parameters:
352
+ patch_size (`int`): Patch size to turn the input data into small patches.
353
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
354
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
355
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
356
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
357
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
358
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
359
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
360
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
361
+ """
362
+
363
+ _supports_gradient_checkpointing = True
364
+
365
+ @register_to_config
366
+ def __init__(
367
+ self,
368
+ patch_size: int = 1,
369
+ in_channels: int = 64,
370
+ num_layers: int = 19,
371
+ num_single_layers: int = 38,
372
+ attention_head_dim: int = 128,
373
+ num_attention_heads: int = 24,
374
+ joint_attention_dim: int = 4096,
375
+ pooled_projection_dim: int = 768,
376
+ guidance_embeds: bool = False,
377
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
378
+ ##
379
+ ):
380
+ super().__init__()
381
+ self.out_channels = in_channels
382
+ self.inner_dim = (
383
+ self.config.num_attention_heads * self.config.attention_head_dim
384
+ )
385
+
386
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
387
+ text_time_guidance_cls = (
388
+ CombinedTimestepGuidanceTextProjEmbeddings
389
+ if guidance_embeds
390
+ else CombinedTimestepTextProjEmbeddings
391
+ )
392
+ self.time_text_embed = text_time_guidance_cls(
393
+ embedding_dim=self.inner_dim,
394
+ pooled_projection_dim=self.config.pooled_projection_dim,
395
+ )
396
+
397
+ self.context_embedder = nn.Linear(
398
+ self.config.joint_attention_dim, self.inner_dim
399
+ )
400
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
401
+
402
+ self.transformer_blocks = nn.ModuleList(
403
+ [
404
+ FluxTransformerBlock(
405
+ dim=self.inner_dim,
406
+ num_attention_heads=self.config.num_attention_heads,
407
+ attention_head_dim=self.config.attention_head_dim,
408
+ )
409
+ for i in range(self.config.num_layers)
410
+ ]
411
+ )
412
+
413
+ self.single_transformer_blocks = nn.ModuleList(
414
+ [
415
+ FluxSingleTransformerBlock(
416
+ dim=self.inner_dim,
417
+ num_attention_heads=self.config.num_attention_heads,
418
+ attention_head_dim=self.config.attention_head_dim,
419
+ )
420
+ for i in range(self.config.num_single_layers)
421
+ ]
422
+ )
423
+
424
+ self.norm_out = AdaLayerNormContinuous(
425
+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
426
+ )
427
+ self.proj_out = nn.Linear(
428
+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True
429
+ )
430
+
431
+ self.gradient_checkpointing = False
432
+
433
+ @property
434
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
435
+ r"""
436
+ Returns:
437
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
438
+ indexed by its weight name.
439
+ """
440
+ # set recursively
441
+ processors = {}
442
+
443
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
444
+ if hasattr(module, "get_processor"):
445
+ processors[f"{name}.processor"] = module.get_processor()
446
+
447
+ for sub_name, child in module.named_children():
448
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
449
+
450
+ return processors
451
+
452
+ for name, module in self.named_children():
453
+ fn_recursive_add_processors(name, module, processors)
454
+
455
+ return processors
456
+
457
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
458
+ r"""
459
+ Sets the attention processor to use to compute attention.
460
+
461
+ Parameters:
462
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
463
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
464
+ for **all** `Attention` layers.
465
+
466
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
467
+ processor. This is strongly recommended when setting trainable attention processors.
468
+
469
+ """
470
+ count = len(self.attn_processors.keys())
471
+
472
+ if isinstance(processor, dict) and len(processor) != count:
473
+ raise ValueError(
474
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
475
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
476
+ )
477
+
478
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
479
+ if hasattr(module, "set_processor"):
480
+ if not isinstance(processor, dict):
481
+ module.set_processor(processor)
482
+ else:
483
+ module.set_processor(processor.pop(f"{name}.processor"))
484
+
485
+ for sub_name, child in module.named_children():
486
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
487
+
488
+ for name, module in self.named_children():
489
+ fn_recursive_attn_processor(name, module, processor)
490
+
491
+ def _set_gradient_checkpointing(self, module, value=False):
492
+ if hasattr(module, "gradient_checkpointing"):
493
+ module.gradient_checkpointing = value
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states: torch.Tensor,
498
+ encoder_hidden_states: torch.Tensor = None,
499
+ pooled_projections: torch.Tensor = None,
500
+ timestep: torch.LongTensor = None,
501
+ img_ids: torch.Tensor = None,
502
+ txt_ids: torch.Tensor = None,
503
+ guidance: torch.Tensor = None,
504
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
505
+ return_dict: bool = True,
506
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
507
+ """
508
+ The [`FluxTransformer2DModelWithMasking`] forward method.
509
+
510
+ Args:
511
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
512
+ Input `hidden_states`.
513
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
514
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
515
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
516
+ from the embeddings of input conditions.
517
+ timestep ( `torch.LongTensor`):
518
+ Used to indicate denoising step.
519
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
520
+ A list of tensors that if specified are added to the residuals of transformer blocks.
521
+ joint_attention_kwargs (`dict`, *optional*):
522
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
523
+ `self.processor` in
524
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
525
+ return_dict (`bool`, *optional*, defaults to `True`):
526
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
527
+ tuple.
528
+
529
+ Returns:
530
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
531
+ `tuple` where the first element is the sample tensor.
532
+ """
533
+ if joint_attention_kwargs is not None:
534
+ joint_attention_kwargs = joint_attention_kwargs.copy()
535
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
536
+ else:
537
+ lora_scale = 1.0
538
+
539
+ if USE_PEFT_BACKEND:
540
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
541
+ scale_lora_layers(self, lora_scale)
542
+ else:
543
+ if (
544
+ joint_attention_kwargs is not None
545
+ and joint_attention_kwargs.get("scale", None) is not None
546
+ ):
547
+ logger.warning(
548
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
549
+ )
550
+ hidden_states = self.x_embedder(hidden_states)
551
+
552
+ timestep = timestep.to(hidden_states.dtype) * 1000
553
+ if guidance is not None:
554
+ guidance = guidance.to(hidden_states.dtype) * 1000
555
+ else:
556
+ guidance = None
557
+ temb = (
558
+ self.time_text_embed(timestep, pooled_projections)
559
+ if guidance is None
560
+ else self.time_text_embed(timestep, guidance, pooled_projections)
561
+ )
562
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
563
+
564
+ if txt_ids.ndim == 3:
565
+ txt_ids = txt_ids[0]
566
+ if img_ids.ndim == 3:
567
+ img_ids = img_ids[0]
568
+
569
+
570
+ # txt_ids = torch.zeros((1024,3)).to(txt_ids.device, dtype=txt_ids.dtype)
571
+ ids = torch.cat((txt_ids, img_ids), dim=0)
572
+
573
+ image_rotary_emb = self.pos_embed(ids)
574
+
575
+ for index_block, block in enumerate(self.transformer_blocks):
576
+ if self.training and self.gradient_checkpointing:
577
+
578
+ def create_custom_forward(module, return_dict=None):
579
+ def custom_forward(*inputs):
580
+ if return_dict is not None:
581
+ return module(*inputs, return_dict=return_dict)
582
+ else:
583
+ return module(*inputs)
584
+
585
+ return custom_forward
586
+
587
+ ckpt_kwargs: Dict[str, Any] = (
588
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
589
+ )
590
+ encoder_hidden_states, hidden_states = (
591
+ torch.utils.checkpoint.checkpoint(
592
+ create_custom_forward(block),
593
+ hidden_states,
594
+ encoder_hidden_states,
595
+ temb,
596
+ image_rotary_emb,
597
+ joint_attention_kwargs,
598
+ **ckpt_kwargs,
599
+ )
600
+ )
601
+
602
+ else:
603
+ encoder_hidden_states, hidden_states = block(
604
+ hidden_states=hidden_states,
605
+ encoder_hidden_states=encoder_hidden_states,
606
+ temb=temb,
607
+ image_rotary_emb=image_rotary_emb,
608
+ joint_attention_kwargs=joint_attention_kwargs,
609
+ )
610
+
611
+ # Flux places the text tokens in front of the image tokens in the
612
+ # sequence.
613
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
614
+
615
+ for index_block, block in enumerate(self.single_transformer_blocks):
616
+ if self.training and self.gradient_checkpointing:
617
+
618
+ def create_custom_forward(module, return_dict=None):
619
+ def custom_forward(*inputs):
620
+ if return_dict is not None:
621
+ return module(*inputs, return_dict=return_dict)
622
+ else:
623
+ return module(*inputs)
624
+
625
+ return custom_forward
626
+
627
+ ckpt_kwargs: Dict[str, Any] = (
628
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
629
+ )
630
+ hidden_states = torch.utils.checkpoint.checkpoint(
631
+ create_custom_forward(block),
632
+ hidden_states,
633
+ temb,
634
+ image_rotary_emb,
635
+ joint_attention_kwargs,
636
+ **ckpt_kwargs,
637
+ )
638
+
639
+ else:
640
+ hidden_states = block(
641
+ hidden_states=hidden_states,
642
+ temb=temb,
643
+ image_rotary_emb=image_rotary_emb,
644
+ joint_attention_kwargs=joint_attention_kwargs,
645
+ )
646
+
647
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
648
+
649
+ hidden_states = self.norm_out(hidden_states, temb)
650
+ output = self.proj_out(hidden_states)
651
+
652
+ if USE_PEFT_BACKEND:
653
+ # remove `lora_scale` from each PEFT layer
654
+ unscale_lora_layers(self, lora_scale)
655
+
656
+ if not return_dict:
657
+ return (output,)
658
+
659
+ return Transformer2DModelOutput(sample=output)
660
+
661
+
662
+ if __name__ == "__main__":
663
+ dtype = torch.bfloat16
664
+ bsz = 2
665
+ img = torch.rand((bsz, 16, 64, 64)).to("cuda", dtype=dtype)
666
+ timestep = torch.tensor([0.5, 0.5]).to("cuda", dtype=torch.float32)
667
+ pooled = torch.rand(bsz, 768).to("cuda", dtype=dtype)
668
+ text = torch.rand((bsz, 512, 4096)).to("cuda", dtype=dtype)
669
+ attn_mask = torch.tensor([[1.0] * 384 + [0.0] * 128] * bsz).to(
670
+ "cuda", dtype=dtype
671
+ ) # Last 128 positions are masked
672
+
673
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
674
+ latents = latents.view(
675
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
676
+ )
677
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
678
+ latents = latents.reshape(
679
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
680
+ )
681
+
682
+ return latents
683
+
684
+ def _prepare_latent_image_ids(
685
+ batch_size, height, width, device="cuda", dtype=dtype
686
+ ):
687
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
688
+ latent_image_ids[..., 1] = (
689
+ latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
690
+ )
691
+ latent_image_ids[..., 2] = (
692
+ latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
693
+ )
694
+
695
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
696
+ latent_image_ids.shape
697
+ )
698
+
699
+ latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
700
+ latent_image_ids = latent_image_ids.reshape(
701
+ batch_size,
702
+ latent_image_id_height * latent_image_id_width,
703
+ latent_image_id_channels,
704
+ )
705
+
706
+ return latent_image_ids.to(device=device, dtype=dtype)
707
+
708
+ txt_ids = torch.zeros(bsz, text.shape[1], 3).to(device="cuda", dtype=dtype)
709
+
710
+ vae_scale_factor = 16
711
+ height = 2 * (int(512) // vae_scale_factor)
712
+ width = 2 * (int(512) // vae_scale_factor)
713
+ img_ids = _prepare_latent_image_ids(bsz, height, width)
714
+ img = _pack_latents(img, img.shape[0], 16, height, width)
715
+
716
+ # Gotta go fast
717
+ transformer = FluxTransformer2DModelWithMasking.from_config(
718
+ {
719
+ "attention_head_dim": 128,
720
+ "guidance_embeds": True,
721
+ "in_channels": 64,
722
+ "joint_attention_dim": 4096,
723
+ "num_attention_heads": 24,
724
+ "num_layers": 4,
725
+ "num_single_layers": 8,
726
+ "patch_size": 1,
727
+ "pooled_projection_dim": 768,
728
+ }
729
+ ).to("cuda", dtype=dtype)
730
+
731
+ guidance = torch.tensor([2.0], device="cuda")
732
+ guidance = guidance.expand(bsz)
733
+
734
+ with torch.no_grad():
735
+ no_mask = transformer(
736
+ img,
737
+ encoder_hidden_states=text,
738
+ pooled_projections=pooled,
739
+ timestep=timestep,
740
+ img_ids=img_ids,
741
+ txt_ids=txt_ids,
742
+ guidance=guidance,
743
+ )
744
+ mask = transformer(
745
+ img,
746
+ encoder_hidden_states=text,
747
+ pooled_projections=pooled,
748
+ timestep=timestep,
749
+ img_ids=img_ids,
750
+ txt_ids=txt_ids,
751
+ guidance=guidance,
752
+ attention_mask=attn_mask,
753
+ )
754
+
755
+ assert torch.allclose(no_mask.sample, mask.sample) is False
756
+ print("Attention masking test ran OK. Differences in output were detected.")
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ diffusers
2
+ torch
3
+ transformers
4
+ peft
5
+ einops
6
+ numpy
7
+ Pillow
8
+ sentencepiece
9
+ huggingface_hub