ABDALLALSWAITI committed on
Commit d62c734 · verified · 1 parent: a662214

Upload FP8 quantized model

Files changed (1)
  1. pipeline_flux_controlnet.py +1181 -0
pipeline_flux_controlnet.py ADDED
@@ -0,0 +1,1181 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+
33
+ # from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
34
+ from controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
35
+
36
+ from diffusers.models.transformers import FluxTransformer2DModel
37
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
38
+ from diffusers.utils import (
39
+ USE_PEFT_BACKEND,
40
+ is_torch_xla_available,
41
+ logging,
42
+ replace_example_docstring,
43
+ scale_lora_layers,
44
+ unscale_lora_layers,
45
+ )
46
+ from diffusers.utils.torch_utils import randn_tensor
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
49
+
50
+
51
+ if is_torch_xla_available():
52
+ import torch_xla.core.xla_model as xm
53
+
54
+ XLA_AVAILABLE = True
55
+ else:
56
+ XLA_AVAILABLE = False
57
+
58
+
59
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
+
61
+ EXAMPLE_DOC_STRING = """
62
+ Examples:
63
+ ```py
64
+ >>> import torch
65
+ >>> from diffusers.utils import load_image
66
+ >>> from diffusers import FluxControlNetPipeline
67
+ >>> from diffusers import FluxControlNetModel
68
+
69
+ >>> base_model = "black-forest-labs/FLUX.1-dev"
70
+ >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
71
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
72
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
73
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
74
+ ... )
75
+ >>> pipe.to("cuda")
76
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
77
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
78
+ >>> image = pipe(
79
+ ... prompt,
80
+ ... control_image=control_image,
81
+ ... control_guidance_start=0.2,
82
+ ... control_guidance_end=0.8,
83
+ ... controlnet_conditioning_scale=1.0,
84
+ ... num_inference_steps=28,
85
+ ... guidance_scale=3.5,
86
+ ... ).images[0]
87
+ >>> image.save("flux.png")
88
+ ```
89
+ """
90
+
91
+
92
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
93
+ def calculate_shift(
94
+ image_seq_len,
95
+ base_seq_len: int = 256,
96
+ max_seq_len: int = 4096,
97
+ base_shift: float = 0.5,
98
+ max_shift: float = 1.15,
99
+ ):
100
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
101
+ b = base_shift - m * base_seq_len
102
+ mu = image_seq_len * m + b
103
+ return mu
104
+
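+ # Illustrative example (assuming the usual 8x VAE downscale): with the defaults above, `mu` is a
+ # linear interpolation between (256, 0.5) and (4096, 1.15), so a 512x512 request (packed sequence
+ # length 1024) gives mu = 0.63 and a 1024x1024 request (sequence length 4096) gives mu = 1.15.
+ # Longer image token sequences therefore receive a stronger timestep shift.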
105
+
106
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
107
+ def retrieve_latents(
108
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
109
+ ):
110
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
111
+ return encoder_output.latent_dist.sample(generator)
112
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
113
+ return encoder_output.latent_dist.mode()
114
+ elif hasattr(encoder_output, "latents"):
115
+ return encoder_output.latents
116
+ else:
117
+ raise AttributeError("Could not access latents of provided encoder_output")
118
+
119
+
120
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
121
+ def retrieve_timesteps(
122
+ scheduler,
123
+ num_inference_steps: Optional[int] = None,
124
+ device: Optional[Union[str, torch.device]] = None,
125
+ timesteps: Optional[List[int]] = None,
126
+ sigmas: Optional[List[float]] = None,
127
+ **kwargs,
128
+ ):
129
+ r"""
130
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
131
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
132
+
133
+ Args:
134
+ scheduler (`SchedulerMixin`):
135
+ The scheduler to get timesteps from.
136
+ num_inference_steps (`int`):
137
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
138
+ must be `None`.
139
+ device (`str` or `torch.device`, *optional*):
140
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
141
+ timesteps (`List[int]`, *optional*):
142
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
143
+ `num_inference_steps` and `sigmas` must be `None`.
144
+ sigmas (`List[float]`, *optional*):
145
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
146
+ `num_inference_steps` and `timesteps` must be `None`.
147
+
148
+ Returns:
149
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
150
+ second element is the number of inference steps.
151
+ """
152
+ if timesteps is not None and sigmas is not None:
153
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
154
+ if timesteps is not None:
155
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
156
+ if not accepts_timesteps:
157
+ raise ValueError(
158
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
159
+ f" timestep schedules. Please check whether you are using the correct scheduler."
160
+ )
161
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
162
+ timesteps = scheduler.timesteps
163
+ num_inference_steps = len(timesteps)
164
+ elif sigmas is not None:
165
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
166
+ if not accept_sigmas:
167
+ raise ValueError(
168
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
169
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
170
+ )
171
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
172
+ timesteps = scheduler.timesteps
173
+ num_inference_steps = len(timesteps)
174
+ else:
175
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
176
+ timesteps = scheduler.timesteps
177
+ return timesteps, num_inference_steps
178
+
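+ # Note: this pipeline calls retrieve_timesteps with `sigmas` plus a `mu` kwarg (see step 5 of
+ # `__call__`); the extra `mu` is simply forwarded to `scheduler.set_timesteps` via **kwargs, which
+ # is how the resolution-dependent shift from `calculate_shift` reaches the flow-match scheduler.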
179
+
180
+ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
181
+ r"""
182
+ The Flux pipeline for text-to-image generation.
183
+
184
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
185
+
186
+ Args:
187
+ transformer ([`FluxTransformer2DModel`]):
188
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
189
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
190
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
191
+ vae ([`AutoencoderKL`]):
192
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
193
+ text_encoder ([`CLIPTextModel`]):
194
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
195
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
196
+ text_encoder_2 ([`T5EncoderModel`]):
197
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
198
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
199
+ tokenizer (`CLIPTokenizer`):
200
+ Tokenizer of class
201
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
202
+ tokenizer_2 (`T5TokenizerFast`):
203
+ Second Tokenizer of class
204
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
205
+ """
206
+
207
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
208
+ _optional_components = ["image_encoder", "feature_extractor"]
209
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image"]
210
+
211
+ def __init__(
212
+ self,
213
+ scheduler: FlowMatchEulerDiscreteScheduler,
214
+ vae: AutoencoderKL,
215
+ text_encoder: CLIPTextModel,
216
+ tokenizer: CLIPTokenizer,
217
+ text_encoder_2: T5EncoderModel,
218
+ tokenizer_2: T5TokenizerFast,
219
+ transformer: FluxTransformer2DModel,
220
+ controlnet: Union[
221
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
222
+ ],
223
+ image_encoder: CLIPVisionModelWithProjection = None,
224
+ feature_extractor: CLIPImageProcessor = None,
225
+ ):
226
+ super().__init__()
227
+ if isinstance(controlnet, (list, tuple)):
228
+ controlnet = FluxMultiControlNetModel(controlnet)
229
+
230
+ self.register_modules(
231
+ vae=vae,
232
+ text_encoder=text_encoder,
233
+ text_encoder_2=text_encoder_2,
234
+ tokenizer=tokenizer,
235
+ tokenizer_2=tokenizer_2,
236
+ transformer=transformer,
237
+ scheduler=scheduler,
238
+ controlnet=controlnet,
239
+ image_encoder=image_encoder,
240
+ feature_extractor=feature_extractor,
241
+ )
242
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
243
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
244
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
245
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
246
+ self.tokenizer_max_length = (
247
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
248
+ )
249
+ self.default_sample_size = 128
250
+
251
+ def _get_t5_prompt_embeds(
252
+ self,
253
+ prompt: Union[str, List[str]] = None,
254
+ num_images_per_prompt: int = 1,
255
+ max_sequence_length: int = 512,
256
+ device: Optional[torch.device] = None,
257
+ dtype: Optional[torch.dtype] = None,
258
+ ):
259
+ device = device or self._execution_device
260
+ dtype = dtype or self.text_encoder.dtype
261
+
262
+ prompt = [prompt] if isinstance(prompt, str) else prompt
263
+ batch_size = len(prompt)
264
+
265
+ if isinstance(self, TextualInversionLoaderMixin):
266
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
267
+
268
+ text_inputs = self.tokenizer_2(
269
+ prompt,
270
+ padding="max_length",
271
+ max_length=max_sequence_length,
272
+ truncation=True,
273
+ return_length=False,
274
+ return_overflowing_tokens=False,
275
+ return_tensors="pt",
276
+ )
277
+ text_input_ids = text_inputs.input_ids
278
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
279
+
280
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
281
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
282
+ logger.warning(
283
+ "The following part of your input was truncated because `max_sequence_length` is set to "
284
+ f" {max_sequence_length} tokens: {removed_text}"
285
+ )
286
+
287
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
288
+
289
+ dtype = self.text_encoder_2.dtype
290
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
291
+
292
+ _, seq_len, _ = prompt_embeds.shape
293
+
294
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
295
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
296
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
297
+
298
+ return prompt_embeds
299
+
300
+ def _get_clip_prompt_embeds(
301
+ self,
302
+ prompt: Union[str, List[str]],
303
+ num_images_per_prompt: int = 1,
304
+ device: Optional[torch.device] = None,
305
+ ):
306
+ device = device or self._execution_device
307
+
308
+ prompt = [prompt] if isinstance(prompt, str) else prompt
309
+ batch_size = len(prompt)
310
+
311
+ if isinstance(self, TextualInversionLoaderMixin):
312
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
313
+
314
+ text_inputs = self.tokenizer(
315
+ prompt,
316
+ padding="max_length",
317
+ max_length=self.tokenizer_max_length,
318
+ truncation=True,
319
+ return_overflowing_tokens=False,
320
+ return_length=False,
321
+ return_tensors="pt",
322
+ )
323
+
324
+ text_input_ids = text_inputs.input_ids
325
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
326
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
327
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
328
+ logger.warning(
329
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
330
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
331
+ )
332
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
333
+
334
+ # Use pooled output of CLIPTextModel
335
+ prompt_embeds = prompt_embeds.pooler_output
336
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
337
+
338
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
339
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
340
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
341
+
342
+ return prompt_embeds
343
+
344
+ def encode_prompt(
345
+ self,
346
+ prompt: Union[str, List[str]],
347
+ prompt_2: Union[str, List[str]],
348
+ device: Optional[torch.device] = None,
349
+ num_images_per_prompt: int = 1,
350
+ prompt_embeds: Optional[torch.FloatTensor] = None,
351
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
352
+ max_sequence_length: int = 512,
353
+ lora_scale: Optional[float] = None,
354
+ ):
355
+ r"""
356
+
357
+ Args:
358
+ prompt (`str` or `List[str]`, *optional*):
359
+ prompt to be encoded
360
+ prompt_2 (`str` or `List[str]`, *optional*):
361
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
362
+ used in all text-encoders
363
+ device: (`torch.device`):
364
+ torch device
365
+ num_images_per_prompt (`int`):
366
+ number of images that should be generated per prompt
367
+ prompt_embeds (`torch.FloatTensor`, *optional*):
368
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
369
+ provided, text embeddings will be generated from `prompt` input argument.
370
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
371
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
372
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
373
+ max_sequence_length (`int`, *optional*, defaults to 512):
374
+ Maximum sequence length to use with the `prompt` for the T5 text encoder; longer prompts are
375
+ truncated to this length.
376
+ lora_scale (`float`, *optional*):
377
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
378
+ """
379
+ device = device or self._execution_device
380
+
381
+ # set lora scale so that monkey patched LoRA
382
+ # function of text encoder can correctly access it
383
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
384
+ self._lora_scale = lora_scale
385
+
386
+ # dynamically adjust the LoRA scale
387
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
388
+ scale_lora_layers(self.text_encoder, lora_scale)
389
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
390
+ scale_lora_layers(self.text_encoder_2, lora_scale)
391
+
392
+ prompt = [prompt] if isinstance(prompt, str) else prompt
393
+
394
+ if prompt_embeds is None:
395
+ prompt_2 = prompt_2 or prompt
396
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
397
+
398
+ # We only use the pooled prompt output from the CLIPTextModel
399
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
400
+ prompt=prompt,
401
+ device=device,
402
+ num_images_per_prompt=num_images_per_prompt,
403
+ )
404
+ prompt_embeds = self._get_t5_prompt_embeds(
405
+ prompt=prompt_2,
406
+ num_images_per_prompt=num_images_per_prompt,
407
+ max_sequence_length=max_sequence_length,
408
+ device=device,
409
+ )
410
+
411
+ if self.text_encoder is not None:
412
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
413
+ # Retrieve the original scale by scaling back the LoRA layers
414
+ unscale_lora_layers(self.text_encoder, lora_scale)
415
+
416
+ if self.text_encoder_2 is not None:
417
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
418
+ # Retrieve the original scale by scaling back the LoRA layers
419
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
420
+
421
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
422
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
423
+
424
+ return prompt_embeds, pooled_prompt_embeds, text_ids
425
+
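+ # In short: the CLIP encoder only contributes its pooled output (`pooled_prompt_embeds`), the T5
+ # encoder provides the per-token sequence (`prompt_embeds`), and `text_ids` is an all-zero
+ # (seq_len, 3) position-id tensor that accompanies the text tokens.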
426
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
427
+ def encode_image(self, image, device, num_images_per_prompt):
428
+ dtype = next(self.image_encoder.parameters()).dtype
429
+
430
+ if not isinstance(image, torch.Tensor):
431
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
432
+
433
+ image = image.to(device=device, dtype=dtype)
434
+ image_embeds = self.image_encoder(image).image_embeds
435
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
436
+ return image_embeds
437
+
438
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
439
+ def prepare_ip_adapter_image_embeds(
440
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
441
+ ):
442
+ image_embeds = []
443
+ if ip_adapter_image_embeds is None:
444
+ if not isinstance(ip_adapter_image, list):
445
+ ip_adapter_image = [ip_adapter_image]
446
+
447
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
448
+ raise ValueError(
449
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
450
+ )
451
+
452
+ for single_ip_adapter_image in ip_adapter_image:
453
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
454
+ image_embeds.append(single_image_embeds[None, :])
455
+ else:
456
+ if not isinstance(ip_adapter_image_embeds, list):
457
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
458
+
459
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
460
+ raise ValueError(
461
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
462
+ )
463
+
464
+ for single_image_embeds in ip_adapter_image_embeds:
465
+ image_embeds.append(single_image_embeds)
466
+
467
+ ip_adapter_image_embeds = []
468
+ for single_image_embeds in image_embeds:
469
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
470
+ single_image_embeds = single_image_embeds.to(device=device)
471
+ ip_adapter_image_embeds.append(single_image_embeds)
472
+
473
+ return ip_adapter_image_embeds
474
+
475
+ def check_inputs(
476
+ self,
477
+ prompt,
478
+ prompt_2,
479
+ height,
480
+ width,
481
+ negative_prompt=None,
482
+ negative_prompt_2=None,
483
+ prompt_embeds=None,
484
+ negative_prompt_embeds=None,
485
+ pooled_prompt_embeds=None,
486
+ negative_pooled_prompt_embeds=None,
487
+ callback_on_step_end_tensor_inputs=None,
488
+ max_sequence_length=None,
489
+ ):
490
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
491
+ logger.warning(
492
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
493
+ )
494
+
495
+ if callback_on_step_end_tensor_inputs is not None and not all(
496
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
497
+ ):
498
+ raise ValueError(
499
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
500
+ )
501
+
502
+ if prompt is not None and prompt_embeds is not None:
503
+ raise ValueError(
504
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
505
+ " only forward one of the two."
506
+ )
507
+ elif prompt_2 is not None and prompt_embeds is not None:
508
+ raise ValueError(
509
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
510
+ " only forward one of the two."
511
+ )
512
+ elif prompt is None and prompt_embeds is None:
513
+ raise ValueError(
514
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
515
+ )
516
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
517
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
518
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
519
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
520
+
521
+ if negative_prompt is not None and negative_prompt_embeds is not None:
522
+ raise ValueError(
523
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
524
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
525
+ )
526
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
527
+ raise ValueError(
528
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
529
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
530
+ )
531
+
532
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
533
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
534
+ raise ValueError(
535
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
536
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
537
+ f" {negative_prompt_embeds.shape}."
538
+ )
539
+
540
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
541
+ raise ValueError(
542
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
543
+ )
544
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
545
+ raise ValueError(
546
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
547
+ )
548
+
549
+ if max_sequence_length is not None and max_sequence_length > 512:
550
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
551
+
552
+ @staticmethod
553
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
554
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
555
+ latent_image_ids = torch.zeros(height, width, 3)
556
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
557
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
558
+
559
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
560
+
561
+ latent_image_ids = latent_image_ids.reshape(
562
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
563
+ )
564
+
565
+ return latent_image_ids.to(device=device, dtype=dtype)
566
+
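+ # The ids built above give each packed latent token a 3-component position id: component 0 stays
+ # zero, component 1 is the patch row, and component 2 is the patch column. They are later passed to
+ # the controlnet and transformer as `img_ids`, alongside `txt_ids` for the text tokens.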
567
+ @staticmethod
568
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
569
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
570
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
571
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
572
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
573
+
574
+ return latents
575
+
576
+ @staticmethod
577
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
578
+ def _unpack_latents(latents, height, width, vae_scale_factor):
579
+ batch_size, num_patches, channels = latents.shape
580
+
581
+ # VAE applies 8x compression on images but we must also account for packing which requires
582
+ # latent height and width to be divisible by 2.
583
+ height = 2 * (int(height) // (vae_scale_factor * 2))
584
+ width = 2 * (int(width) // (vae_scale_factor * 2))
585
+
586
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
587
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
588
+
589
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
590
+
591
+ return latents
592
+
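+ # Shape walkthrough (illustrative, assuming the usual FLUX.1 configuration: 16 latent channels and
+ # vae_scale_factor 8): a 1024x1024 request gives a (B, 16, 128, 128) latent, which _pack_latents
+ # turns into (B, 64 * 64, 64) = (B, 4096, 64) by grouping 2x2 patches; _unpack_latents reverses
+ # exactly that before VAE decoding.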
593
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
594
+ def prepare_latents(
595
+ self,
596
+ batch_size,
597
+ num_channels_latents,
598
+ height,
599
+ width,
600
+ dtype,
601
+ device,
602
+ generator,
603
+ latents=None,
604
+ ):
605
+ # VAE applies 8x compression on images but we must also account for packing which requires
606
+ # latent height and width to be divisible by 2.
607
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
608
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
609
+
610
+ shape = (batch_size, num_channels_latents, height, width)
611
+
612
+ if latents is not None:
613
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
614
+ return latents.to(device=device, dtype=dtype), latent_image_ids
615
+
616
+ if isinstance(generator, list) and len(generator) != batch_size:
617
+ raise ValueError(
618
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
619
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
620
+ )
621
+
622
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
623
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
624
+
625
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
626
+
627
+ return latents, latent_image_ids
628
+
629
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
630
+ def prepare_image(
631
+ self,
632
+ image,
633
+ width,
634
+ height,
635
+ batch_size,
636
+ num_images_per_prompt,
637
+ device,
638
+ dtype,
639
+ do_classifier_free_guidance=False,
640
+ guess_mode=False,
641
+ ):
642
+ if isinstance(image, torch.Tensor):
643
+ pass
644
+ else:
645
+ image = self.image_processor.preprocess(image, height=height, width=width)
646
+
647
+ image_batch_size = image.shape[0]
648
+
649
+ if image_batch_size == 1:
650
+ repeat_by = batch_size
651
+ else:
652
+ # image batch size is the same as prompt batch size
653
+ repeat_by = num_images_per_prompt
654
+
655
+ image = image.repeat_interleave(repeat_by, dim=0)
656
+
657
+ image = image.to(device=device, dtype=dtype)
658
+
659
+ if do_classifier_free_guidance and not guess_mode:
660
+ image = torch.cat([image] * 2)
661
+
662
+ return image
663
+
664
+ @property
665
+ def guidance_scale(self):
666
+ return self._guidance_scale
667
+
668
+ @property
669
+ def joint_attention_kwargs(self):
670
+ return self._joint_attention_kwargs
671
+
672
+ @property
673
+ def num_timesteps(self):
674
+ return self._num_timesteps
675
+
676
+ @property
677
+ def interrupt(self):
678
+ return self._interrupt
679
+
680
+ @torch.no_grad()
681
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
682
+ def __call__(
683
+ self,
684
+ prompt: Union[str, List[str]] = None,
685
+ prompt_2: Optional[Union[str, List[str]]] = None,
686
+ negative_prompt: Union[str, List[str]] = None,
687
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
688
+ true_cfg_scale: float = 1.0,
689
+ height: Optional[int] = None,
690
+ width: Optional[int] = None,
691
+ num_inference_steps: int = 28,
692
+ sigmas: Optional[List[float]] = None,
693
+ guidance_scale: float = 7.0,
694
+ control_guidance_start: Union[float, List[float]] = 0.0,
695
+ control_guidance_end: Union[float, List[float]] = 1.0,
696
+ control_image: PipelineImageInput = None,
697
+ control_mode: Optional[Union[int, List[int]]] = None,
698
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
699
+ num_images_per_prompt: Optional[int] = 1,
700
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
701
+ latents: Optional[torch.FloatTensor] = None,
702
+ prompt_embeds: Optional[torch.FloatTensor] = None,
703
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
704
+ ip_adapter_image: Optional[PipelineImageInput] = None,
705
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
706
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
707
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
708
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
709
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
710
+ output_type: Optional[str] = "pil",
711
+ return_dict: bool = True,
712
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
713
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
714
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
715
+ max_sequence_length: int = 512,
716
+ ):
717
+ r"""
718
+ Function invoked when calling the pipeline for generation.
719
+
720
+ Args:
721
+ prompt (`str` or `List[str]`, *optional*):
722
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
723
+ instead.
724
+ prompt_2 (`str` or `List[str]`, *optional*):
725
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
726
+ used instead.
727
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
728
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
729
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
730
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
731
+ num_inference_steps (`int`, *optional*, defaults to 28):
732
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
733
+ expense of slower inference.
734
+ sigmas (`List[float]`, *optional*):
735
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
736
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
737
+ will be used.
738
+ guidance_scale (`float`, *optional*, defaults to 7.0):
739
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
740
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
741
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
742
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
743
+ usually at the expense of lower image quality.
744
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
745
+ The percentage of total steps at which the ControlNet starts applying.
746
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
747
+ The percentage of total steps at which the ControlNet stops applying.
748
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
749
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
750
+ The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
751
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
752
+ as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
753
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
754
+ images must be passed as a list such that each element of the list can be correctly batched for input
755
+ to a single ControlNet.
756
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
757
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
758
+ to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you can set
759
+ the corresponding scale as a list.
760
+ control_mode (`int` or `List[int]`, *optional*, defaults to None):
761
+ The control mode when applying ControlNet-Union.
762
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
763
+ The number of images to generate per prompt.
764
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
765
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
766
+ to make generation deterministic.
767
+ latents (`torch.FloatTensor`, *optional*):
768
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
769
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
770
+ tensor will be generated by sampling using the supplied random `generator`.
771
+ prompt_embeds (`torch.FloatTensor`, *optional*):
772
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
773
+ provided, text embeddings will be generated from `prompt` input argument.
774
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
775
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
776
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
777
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
778
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
779
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of
780
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
781
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
782
+ negative_ip_adapter_image:
783
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
784
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
785
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of
786
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
787
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
788
+ output_type (`str`, *optional*, defaults to `"pil"`):
789
+ The output format of the generated image. Choose between
790
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
791
+ return_dict (`bool`, *optional*, defaults to `True`):
792
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
793
+ joint_attention_kwargs (`dict`, *optional*):
794
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
795
+ `self.processor` in
796
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
797
+ callback_on_step_end (`Callable`, *optional*):
798
+ A function that is called at the end of each denoising step during inference. The function is called
799
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
800
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
801
+ `callback_on_step_end_tensor_inputs`.
802
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
803
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
804
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
805
+ `._callback_tensor_inputs` attribute of your pipeline class.
806
+ max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`.
807
+
808
+ Examples:
809
+
810
+ Returns:
811
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
812
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
813
+ images.
814
+ """
815
+
816
+ height = height or self.default_sample_size * self.vae_scale_factor
817
+ width = width or self.default_sample_size * self.vae_scale_factor
818
+
819
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
820
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
821
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
822
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
823
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
824
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
825
+ control_guidance_start, control_guidance_end = (
826
+ mult * [control_guidance_start],
827
+ mult * [control_guidance_end],
828
+ )
829
+
830
+ # 1. Check inputs. Raise error if not correct
831
+ self.check_inputs(
832
+ prompt,
833
+ prompt_2,
834
+ height,
835
+ width,
836
+ negative_prompt=negative_prompt,
837
+ negative_prompt_2=negative_prompt_2,
838
+ prompt_embeds=prompt_embeds,
839
+ negative_prompt_embeds=negative_prompt_embeds,
840
+ pooled_prompt_embeds=pooled_prompt_embeds,
841
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
842
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
843
+ max_sequence_length=max_sequence_length,
844
+ )
845
+
846
+ self._guidance_scale = guidance_scale
847
+ self._joint_attention_kwargs = joint_attention_kwargs
848
+ self._interrupt = False
849
+
850
+ # 2. Define call parameters
851
+ if prompt is not None and isinstance(prompt, str):
852
+ batch_size = 1
853
+ elif prompt is not None and isinstance(prompt, list):
854
+ batch_size = len(prompt)
855
+ else:
856
+ batch_size = prompt_embeds.shape[0]
857
+
858
+ device = self._execution_device
859
+ dtype = self.transformer.dtype
860
+
861
+ # 3. Prepare text embeddings
862
+ lora_scale = (
863
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
864
+ )
865
+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
866
+ (
867
+ prompt_embeds,
868
+ pooled_prompt_embeds,
869
+ text_ids,
870
+ ) = self.encode_prompt(
871
+ prompt=prompt,
872
+ prompt_2=prompt_2,
873
+ prompt_embeds=prompt_embeds,
874
+ pooled_prompt_embeds=pooled_prompt_embeds,
875
+ device=device,
876
+ num_images_per_prompt=num_images_per_prompt,
877
+ max_sequence_length=max_sequence_length,
878
+ lora_scale=lora_scale,
879
+ )
880
+ if do_true_cfg:
881
+ (
882
+ negative_prompt_embeds,
883
+ negative_pooled_prompt_embeds,
884
+ _,
885
+ ) = self.encode_prompt(
886
+ prompt=negative_prompt,
887
+ prompt_2=negative_prompt_2,
888
+ prompt_embeds=negative_prompt_embeds,
889
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
890
+ device=device,
891
+ num_images_per_prompt=num_images_per_prompt,
892
+ max_sequence_length=max_sequence_length,
893
+ lora_scale=lora_scale,
894
+ )
895
+
896
+ # 3. Prepare control image
897
+ num_channels_latents = self.transformer.config.in_channels // 4
898
+ if isinstance(self.controlnet, FluxControlNetModel):
899
+ control_image = self.prepare_image(
900
+ image=control_image,
901
+ width=width,
902
+ height=height,
903
+ batch_size=batch_size * num_images_per_prompt,
904
+ num_images_per_prompt=num_images_per_prompt,
905
+ device=device,
906
+ dtype=self.vae.dtype,
907
+ )
908
+ height, width = control_image.shape[-2:]
909
+
910
+ # xlab controlnet has an input_hint_block and instantx controlnet does not
911
+ controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
912
+ if self.controlnet.input_hint_block is None:
913
+ # vae encode
914
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
915
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
916
+
917
+ # pack
918
+ height_control_image, width_control_image = control_image.shape[2:]
919
+ control_image = self._pack_latents(
920
+ control_image,
921
+ batch_size * num_images_per_prompt,
922
+ num_channels_latents,
923
+ height_control_image,
924
+ width_control_image,
925
+ )
926
+
927
+ # Here we ensure that `control_mode` has the same length as the control_image.
928
+ if control_mode is not None:
929
+ if not isinstance(control_mode, int):
930
+ raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
931
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
932
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
933
+
934
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
935
+ control_images = []
936
+ # xlab controlnet has an input_hint_block and instantx controlnet does not
937
+ controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
938
+ for i, control_image_ in enumerate(control_image):
939
+ control_image_ = self.prepare_image(
940
+ image=control_image_,
941
+ width=width,
942
+ height=height,
943
+ batch_size=batch_size * num_images_per_prompt,
944
+ num_images_per_prompt=num_images_per_prompt,
945
+ device=device,
946
+ dtype=self.vae.dtype,
947
+ )
948
+ height, width = control_image_.shape[-2:]
949
+
950
+ if self.controlnet.nets[0].input_hint_block is None:
951
+ # vae encode
952
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
953
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
954
+
955
+ # pack
956
+ height_control_image, width_control_image = control_image_.shape[2:]
957
+ control_image_ = self._pack_latents(
958
+ control_image_,
959
+ batch_size * num_images_per_prompt,
960
+ num_channels_latents,
961
+ height_control_image,
962
+ width_control_image,
963
+ )
964
+ control_images.append(control_image_)
965
+
966
+ control_image = control_images
967
+
968
+ # Here we ensure that `control_mode` has the same length as the control_image.
969
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
970
+ raise ValueError(
971
+ "For Multi-ControlNet, `control_mode` must be a list of the same "
972
+ + " length as the number of controlnets (control images) specified"
973
+ )
974
+ if not isinstance(control_mode, list):
975
+ control_mode = [control_mode] * len(control_image)
976
+ # set control mode
977
+ control_modes = []
978
+ for cmode in control_mode:
979
+ if cmode is None:
980
+ cmode = -1
981
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
982
+ control_modes.append(control_mode)
983
+ control_mode = control_modes
984
+
985
+ # 4. Prepare latent variables
986
+ num_channels_latents = self.transformer.config.in_channels // 4
987
+ latents, latent_image_ids = self.prepare_latents(
988
+ batch_size * num_images_per_prompt,
989
+ num_channels_latents,
990
+ height,
991
+ width,
992
+ prompt_embeds.dtype,
993
+ device,
994
+ generator,
995
+ latents,
996
+ )
997
+
998
+ # 5. Prepare timesteps
999
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1000
+ image_seq_len = latents.shape[1]
1001
+ mu = calculate_shift(
1002
+ image_seq_len,
1003
+ self.scheduler.config.get("base_image_seq_len", 256),
1004
+ self.scheduler.config.get("max_image_seq_len", 4096),
1005
+ self.scheduler.config.get("base_shift", 0.5),
1006
+ self.scheduler.config.get("max_shift", 1.15),
1007
+ )
1008
+ timesteps, num_inference_steps = retrieve_timesteps(
1009
+ self.scheduler,
1010
+ num_inference_steps,
1011
+ device,
1012
+ sigmas=sigmas,
1013
+ mu=mu,
1014
+ )
1015
+
1016
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1017
+ self._num_timesteps = len(timesteps)
1018
+
1019
+ # 6. Create tensor stating which controlnets to keep
1020
+ controlnet_keep = []
1021
+ for i in range(len(timesteps)):
1022
+ keeps = [
1023
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1024
+ for s, e in zip(control_guidance_start, control_guidance_end)
1025
+ ]
1026
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
1027
+
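+ # Illustrative example: with control_guidance_start=0.2, control_guidance_end=0.8 and 28 steps,
+ # `keeps` is 1.0 only for steps 6 through 21, so the controlnet conditioning scale is zeroed out
+ # during roughly the first and last 20% of the denoising schedule.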
1028
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1029
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1030
+ ):
1031
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1032
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1033
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1034
+ ):
1035
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1036
+
1037
+ if self.joint_attention_kwargs is None:
1038
+ self._joint_attention_kwargs = {}
1039
+
1040
+ image_embeds = None
1041
+ negative_image_embeds = None
1042
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1043
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1044
+ ip_adapter_image,
1045
+ ip_adapter_image_embeds,
1046
+ device,
1047
+ batch_size * num_images_per_prompt,
1048
+ )
1049
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1050
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1051
+ negative_ip_adapter_image,
1052
+ negative_ip_adapter_image_embeds,
1053
+ device,
1054
+ batch_size * num_images_per_prompt,
1055
+ )
1056
+
1057
+ # 7. Denoising loop
1058
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1059
+ for i, t in enumerate(timesteps):
1060
+ if self.interrupt:
1061
+ continue
1062
+
1063
+ if image_embeds is not None:
1064
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1065
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1066
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1067
+
1068
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
1069
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
1070
+ else:
1071
+ use_guidance = self.controlnet.config.guidance_embeds
1072
+
1073
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
1074
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
1075
+
1076
+ if isinstance(controlnet_keep[i], list):
1077
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1078
+ else:
1079
+ controlnet_cond_scale = controlnet_conditioning_scale
1080
+ if isinstance(controlnet_cond_scale, list):
1081
+ controlnet_cond_scale = controlnet_cond_scale[0]
1082
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1083
+
1084
+ # controlnet
1085
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
1086
+ hidden_states=latents,
1087
+ controlnet_cond=control_image,
1088
+ controlnet_mode=control_mode,
1089
+ conditioning_scale=cond_scale,
1090
+ timestep=timestep / 1000,
1091
+ guidance=guidance,
1092
+ pooled_projections=pooled_prompt_embeds,
1093
+ encoder_hidden_states=prompt_embeds,
1094
+ txt_ids=text_ids,
1095
+ img_ids=latent_image_ids,
1096
+ joint_attention_kwargs=self.joint_attention_kwargs,
1097
+ return_dict=False,
1098
+ )
1099
+
1100
+ guidance = (
1101
+ torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
1102
+ )
1103
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
1104
+
1105
+ noise_pred = self.transformer(
1106
+ hidden_states=latents,
1107
+ timestep=timestep / 1000,
1108
+ guidance=guidance,
1109
+ pooled_projections=pooled_prompt_embeds,
1110
+ encoder_hidden_states=prompt_embeds,
1111
+ controlnet_block_samples=controlnet_block_samples,
1112
+ controlnet_single_block_samples=controlnet_single_block_samples,
1113
+ txt_ids=text_ids,
1114
+ img_ids=latent_image_ids,
1115
+ joint_attention_kwargs=self.joint_attention_kwargs,
1116
+ return_dict=False,
1117
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
1118
+ )[0]
1119
+
1120
+ if do_true_cfg:
1121
+ if negative_image_embeds is not None:
1122
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1123
+ neg_noise_pred = self.transformer(
1124
+ hidden_states=latents,
1125
+ timestep=timestep / 1000,
1126
+ guidance=guidance,
1127
+ pooled_projections=negative_pooled_prompt_embeds,
1128
+ encoder_hidden_states=negative_prompt_embeds,
1129
+ controlnet_block_samples=controlnet_block_samples,
1130
+ controlnet_single_block_samples=controlnet_single_block_samples,
1131
+ txt_ids=text_ids,
1132
+ img_ids=latent_image_ids,
1133
+ joint_attention_kwargs=self.joint_attention_kwargs,
1134
+ return_dict=False,
1135
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
1136
+ )[0]
1137
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1138
+
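+ # Note: the true-CFG branch above runs a second transformer pass with the negative prompt
+ # embeddings and mixes the two predictions, roughly doubling the per-step cost; it is only active
+ # when `true_cfg_scale > 1` and a `negative_prompt` is provided.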
1139
+ # compute the previous noisy sample x_t -> x_t-1
1140
+ latents_dtype = latents.dtype
1141
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1142
+
1143
+ if latents.dtype != latents_dtype:
1144
+ if torch.backends.mps.is_available():
1145
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1146
+ latents = latents.to(latents_dtype)
1147
+
1148
+ if callback_on_step_end is not None:
1149
+ callback_kwargs = {}
1150
+ for k in callback_on_step_end_tensor_inputs:
1151
+ callback_kwargs[k] = locals()[k]
1152
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1153
+
1154
+ latents = callback_outputs.pop("latents", latents)
1155
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1156
+ control_image = callback_outputs.pop("control_image", control_image)
1157
+
1158
+ # call the callback, if provided
1159
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1160
+ progress_bar.update()
1161
+
1162
+ if XLA_AVAILABLE:
1163
+ xm.mark_step()
1164
+
1165
+ if output_type == "latent":
1166
+ image = latents
1167
+
1168
+ else:
1169
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1170
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1171
+
1172
+ image = self.vae.decode(latents, return_dict=False)[0]
1173
+ image = self.image_processor.postprocess(image, output_type=output_type)
1174
+
1175
+ # Offload all models
1176
+ self.maybe_free_model_hooks()
1177
+
1178
+ if not return_dict:
1179
+ return (image,)
1180
+
1181
+ return FluxPipelineOutput(images=image)
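
A minimal usage sketch for this custom pipeline file (illustrative only: the companion `controlnet_flux.py` module and the checkpoint ids below are assumptions mirroring the docstring example, and should be adjusted to the actual repository contents):

```py
import torch
from diffusers.utils import load_image

# local modules uploaded alongside this file (assumed to sit on the Python path)
from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet import FluxControlNetPipeline

# assumed checkpoints, mirroring the example in the docstring above
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # or pipe.to("cuda") if enough VRAM is available

control_image = load_image(
    "https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg"
)
image = pipe(
    "A girl in city, 25 years old, cool, futuristic",
    control_image=control_image,
    control_guidance_start=0.2,
    control_guidance_end=0.8,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_controlnet.png")
```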