wanghaofan committed on commit ee7b075 (verified · 1 parent: a2350da)

Upload 2 files

Files changed (2):
  1. controlnet_flux.py +509 -0
  2. pipeline_flux_controlnet.py +1181 -0
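The two files are meant to be used together: `pipeline_flux_controlnet.py` imports `FluxControlNetModel` and `FluxMultiControlNetModel` from the local `controlnet_flux.py` rather than from the diffusers build. A minimal usage sketch follows; the model ids and call arguments are reused from the example docstring inside `pipeline_flux_controlnet.py`, and it assumes both files sit next to the calling script so the local imports resolve.

```py
import torch
from diffusers.utils import load_image

from controlnet_flux import FluxControlNetModel
from pipeline_flux_controlnet import FluxControlNetPipeline

# Load the local ControlNet implementation with the InstantX canny checkpoint,
# then build the local pipeline around the FLUX.1-dev base model.
controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image(
    "https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg"
)
image = pipe(
    "A girl in city, 25 years old, cool, futuristic",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux.png")
```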
controlnet_flux.py ADDED
@@ -0,0 +1,509 @@
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FluxControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]
    controlnet_single_block_samples: Tuple[torch.Tensor]


class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        num_mode: int = None,
        conditioning_embedding_channels: int = None,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_single_layers)
            ]
        )

        # controlnet_blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.union = num_mode is not None
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

        if conditioning_embedding_channels is not None:
            self.input_hint_block = ControlNetConditioningEmbedding(
                conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
            )
            self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
        else:
            self.input_hint_block = None
            self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self):
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 4,
        num_single_layers: int = 10,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        load_weights_from_transformer=True,
    ):
        config = dict(transformer.config)
        config["num_layers"] = num_layers
        config["num_single_layers"] = num_single_layers
        config["attention_head_dim"] = attention_head_dim
        config["num_attention_heads"] = num_attention_heads

        controlnet = cls.from_config(config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
            controlnet.single_transformer_blocks.load_state_dict(
                transformer.single_transformer_blocks.state_dict(), strict=False
            )

            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

        return controlnet

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        controlnet_mode: torch.Tensor = None,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxControlNetModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            controlnet_cond (`torch.Tensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            controlnet_mode (`torch.Tensor`):
                The mode tensor of shape `(batch_size, 1)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate the denoising step.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)

        if self.input_hint_block is not None:
            controlnet_cond = self.input_hint_block(controlnet_cond)
            batch_size, channels, height_pw, width_pw = controlnet_cond.shape
            height = height_pw // self.config.patch_size
            width = width_pw // self.config.patch_size
            controlnet_cond = controlnet_cond.reshape(
                batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
            )
            controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
            controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
        # add
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor."
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor."
            )
            img_ids = img_ids[0]

        if self.union:
            # union mode
            if controlnet_mode is None:
                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
            # union mode emb
            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
            txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            block_samples = block_samples + (hidden_states,)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    temb,
                    image_rotary_emb,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)

        # controlnet block
        controlnet_block_samples = ()
        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
            block_sample = controlnet_block(block_sample)
            controlnet_block_samples = controlnet_block_samples + (block_sample,)

        controlnet_single_block_samples = ()
        for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
            single_block_sample = controlnet_block(single_block_sample)
            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)

        # scaling
        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]

        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
        controlnet_single_block_samples = (
            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_block_samples, controlnet_single_block_samples)

        return FluxControlNetOutput(
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
        )


class FluxMultiControlNetModel(ModelMixin):
    r"""
    `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel.

    This module is a wrapper for multiple instances of `FluxControlNetModel`. The `forward()` API is designed to be
    compatible with `FluxControlNetModel`.

    Args:
        controlnets (`List[FluxControlNetModel]`):
            Provides additional conditioning to the unet during the denoising process. You must set multiple
            `FluxControlNetModel` as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        controlnet_cond: List[torch.Tensor],
        controlnet_mode: List[torch.Tensor],
        conditioning_scale: List[float],
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[FluxControlNetOutput, Tuple]:
        # ControlNet-Union with multiple conditions
        # only one ControlNet is loaded, to save memory
        if len(self.nets) == 1:
            controlnet = self.nets[0]

            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
                    if single_block_samples is not None and control_single_block_samples is not None:
                        control_single_block_samples = [
                            control_single_block_sample + block_sample
                            for control_single_block_sample, block_sample in zip(
                                control_single_block_samples, single_block_samples
                            )
                        ]

        # Regular Multi-ControlNets
        # all ControlNets are loaded into memory
        else:
            for i, (image, mode, scale, controlnet) in enumerate(
                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
            ):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
                    if single_block_samples is not None and control_single_block_samples is not None:
                        control_single_block_samples = [
                            control_single_block_sample + block_sample
                            for control_single_block_sample, block_sample in zip(
                                control_single_block_samples, single_block_samples
                            )
                        ]

        return control_block_samples, control_single_block_samples
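`FluxControlNetModel.from_transformer` builds a shallower copy of a pretrained Flux transformer: only the first `num_layers` double-stream and `num_single_layers` single-stream blocks receive weights, and the control embedder is zero-initialized so the control branch starts as a no-op. A hedged sketch of how it could be called when initializing a new ControlNet for training; the repo id is an assumption reused from the pipeline's example docstring:

```py
import torch
from diffusers import FluxTransformer2DModel

from controlnet_flux import FluxControlNetModel

# Assumed base checkpoint; any Flux transformer with a compatible config should work.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Initialize a reduced-depth ControlNet whose embedders and first blocks are
# copied from the transformer; controlnet_x_embedder and the output projections
# are zeroed, so the ControlNet contributes nothing until it is trained.
controlnet = FluxControlNetModel.from_transformer(
    transformer,
    num_layers=4,
    num_single_layers=10,
    load_weights_from_transformer=True,
)
```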
pipeline_flux_controlnet.py ADDED
@@ -0,0 +1,1181 @@
1
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+
33
+ # from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
34
+ from controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
35
+
36
+ from diffusers.models.transformers import FluxTransformer2DModel
37
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
38
+ from diffusers.utils import (
39
+ USE_PEFT_BACKEND,
40
+ is_torch_xla_available,
41
+ logging,
42
+ replace_example_docstring,
43
+ scale_lora_layers,
44
+ unscale_lora_layers,
45
+ )
46
+ from diffusers.utils.torch_utils import randn_tensor
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
49
+
50
+
51
+ if is_torch_xla_available():
52
+ import torch_xla.core.xla_model as xm
53
+
54
+ XLA_AVAILABLE = True
55
+ else:
56
+ XLA_AVAILABLE = False
57
+
58
+
59
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
+
61
+ EXAMPLE_DOC_STRING = """
62
+ Examples:
63
+ ```py
64
+ >>> import torch
65
+ >>> from diffusers.utils import load_image
66
+ >>> from diffusers import FluxControlNetPipeline
67
+ >>> from diffusers import FluxControlNetModel
68
+
69
+ >>> base_model = "black-forest-labs/FLUX.1-dev"
70
+ >>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
71
+ >>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
72
+ >>> pipe = FluxControlNetPipeline.from_pretrained(
73
+ ... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
74
+ ... )
75
+ >>> pipe.to("cuda")
76
+ >>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
77
+ >>> prompt = "A girl in city, 25 years old, cool, futuristic"
78
+ >>> image = pipe(
79
+ ... prompt,
80
+ ... control_image=control_image,
81
+ ... control_guidance_start=0.2,
82
+ ... control_guidance_end=0.8,
83
+ ... controlnet_conditioning_scale=1.0,
84
+ ... num_inference_steps=28,
85
+ ... guidance_scale=3.5,
86
+ ... ).images[0]
87
+ >>> image.save("flux.png")
88
+ ```
89
+ """
90
+
91
+
92
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
93
+ def calculate_shift(
94
+ image_seq_len,
95
+ base_seq_len: int = 256,
96
+ max_seq_len: int = 4096,
97
+ base_shift: float = 0.5,
98
+ max_shift: float = 1.15,
99
+ ):
100
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
101
+ b = base_shift - m * base_seq_len
102
+ mu = image_seq_len * m + b
103
+ return mu
104
+
105
+
106
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
107
+ def retrieve_latents(
108
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
109
+ ):
110
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
111
+ return encoder_output.latent_dist.sample(generator)
112
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
113
+ return encoder_output.latent_dist.mode()
114
+ elif hasattr(encoder_output, "latents"):
115
+ return encoder_output.latents
116
+ else:
117
+ raise AttributeError("Could not access latents of provided encoder_output")
118
+
119
+
120
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
121
+ def retrieve_timesteps(
122
+ scheduler,
123
+ num_inference_steps: Optional[int] = None,
124
+ device: Optional[Union[str, torch.device]] = None,
125
+ timesteps: Optional[List[int]] = None,
126
+ sigmas: Optional[List[float]] = None,
127
+ **kwargs,
128
+ ):
129
+ r"""
130
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
131
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
132
+
133
+ Args:
134
+ scheduler (`SchedulerMixin`):
135
+ The scheduler to get timesteps from.
136
+ num_inference_steps (`int`):
137
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
138
+ must be `None`.
139
+ device (`str` or `torch.device`, *optional*):
140
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
141
+ timesteps (`List[int]`, *optional*):
142
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
143
+ `num_inference_steps` and `sigmas` must be `None`.
144
+ sigmas (`List[float]`, *optional*):
145
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
146
+ `num_inference_steps` and `timesteps` must be `None`.
147
+
148
+ Returns:
149
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
150
+ second element is the number of inference steps.
151
+ """
152
+ if timesteps is not None and sigmas is not None:
153
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
154
+ if timesteps is not None:
155
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
156
+ if not accepts_timesteps:
157
+ raise ValueError(
158
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
159
+ f" timestep schedules. Please check whether you are using the correct scheduler."
160
+ )
161
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
162
+ timesteps = scheduler.timesteps
163
+ num_inference_steps = len(timesteps)
164
+ elif sigmas is not None:
165
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
166
+ if not accept_sigmas:
167
+ raise ValueError(
168
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
169
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
170
+ )
171
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
172
+ timesteps = scheduler.timesteps
173
+ num_inference_steps = len(timesteps)
174
+ else:
175
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
176
+ timesteps = scheduler.timesteps
177
+ return timesteps, num_inference_steps
178
+
179
+
180
+ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin, FluxIPAdapterMixin):
181
+ r"""
182
+ The Flux pipeline for text-to-image generation.
183
+
184
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
185
+
186
+ Args:
187
+ transformer ([`FluxTransformer2DModel`]):
188
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
189
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
190
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
191
+ vae ([`AutoencoderKL`]):
192
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
193
+ text_encoder ([`CLIPTextModel`]):
194
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
195
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
196
+ text_encoder_2 ([`T5EncoderModel`]):
197
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
198
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
199
+ tokenizer (`CLIPTokenizer`):
200
+ Tokenizer of class
201
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
202
+ tokenizer_2 (`T5TokenizerFast`):
203
+ Second Tokenizer of class
204
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
205
+ """
206
+
207
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
208
+ _optional_components = ["image_encoder", "feature_extractor"]
209
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "control_image"]
210
+
211
+ def __init__(
212
+ self,
213
+ scheduler: FlowMatchEulerDiscreteScheduler,
214
+ vae: AutoencoderKL,
215
+ text_encoder: CLIPTextModel,
216
+ tokenizer: CLIPTokenizer,
217
+ text_encoder_2: T5EncoderModel,
218
+ tokenizer_2: T5TokenizerFast,
219
+ transformer: FluxTransformer2DModel,
220
+ controlnet: Union[
221
+ FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
222
+ ],
223
+ image_encoder: CLIPVisionModelWithProjection = None,
224
+ feature_extractor: CLIPImageProcessor = None,
225
+ ):
226
+ super().__init__()
227
+ if isinstance(controlnet, (list, tuple)):
228
+ controlnet = FluxMultiControlNetModel(controlnet)
229
+
230
+ self.register_modules(
231
+ vae=vae,
232
+ text_encoder=text_encoder,
233
+ text_encoder_2=text_encoder_2,
234
+ tokenizer=tokenizer,
235
+ tokenizer_2=tokenizer_2,
236
+ transformer=transformer,
237
+ scheduler=scheduler,
238
+ controlnet=controlnet,
239
+ image_encoder=image_encoder,
240
+ feature_extractor=feature_extractor,
241
+ )
242
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
243
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
244
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
245
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
246
+ self.tokenizer_max_length = (
247
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
248
+ )
249
+ self.default_sample_size = 128
250
+
251
+ def _get_t5_prompt_embeds(
252
+ self,
253
+ prompt: Union[str, List[str]] = None,
254
+ num_images_per_prompt: int = 1,
255
+ max_sequence_length: int = 512,
256
+ device: Optional[torch.device] = None,
257
+ dtype: Optional[torch.dtype] = None,
258
+ ):
259
+ device = device or self._execution_device
260
+ dtype = dtype or self.text_encoder.dtype
261
+
262
+ prompt = [prompt] if isinstance(prompt, str) else prompt
263
+ batch_size = len(prompt)
264
+
265
+ if isinstance(self, TextualInversionLoaderMixin):
266
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
267
+
268
+ text_inputs = self.tokenizer_2(
269
+ prompt,
270
+ padding="max_length",
271
+ max_length=max_sequence_length,
272
+ truncation=True,
273
+ return_length=False,
274
+ return_overflowing_tokens=False,
275
+ return_tensors="pt",
276
+ )
277
+ text_input_ids = text_inputs.input_ids
278
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
279
+
280
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
281
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
282
+ logger.warning(
283
+ "The following part of your input was truncated because `max_sequence_length` is set to "
284
+ f" {max_sequence_length} tokens: {removed_text}"
285
+ )
286
+
287
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
288
+
289
+ dtype = self.text_encoder_2.dtype
290
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
291
+
292
+ _, seq_len, _ = prompt_embeds.shape
293
+
294
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
295
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
296
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
297
+
298
+ return prompt_embeds
299
+
300
+ def _get_clip_prompt_embeds(
301
+ self,
302
+ prompt: Union[str, List[str]],
303
+ num_images_per_prompt: int = 1,
304
+ device: Optional[torch.device] = None,
305
+ ):
306
+ device = device or self._execution_device
307
+
308
+ prompt = [prompt] if isinstance(prompt, str) else prompt
309
+ batch_size = len(prompt)
310
+
311
+ if isinstance(self, TextualInversionLoaderMixin):
312
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
313
+
314
+ text_inputs = self.tokenizer(
315
+ prompt,
316
+ padding="max_length",
317
+ max_length=self.tokenizer_max_length,
318
+ truncation=True,
319
+ return_overflowing_tokens=False,
320
+ return_length=False,
321
+ return_tensors="pt",
322
+ )
323
+
324
+ text_input_ids = text_inputs.input_ids
325
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
326
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
327
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
328
+ logger.warning(
329
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
330
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
331
+ )
332
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
333
+
334
+ # Use pooled output of CLIPTextModel
335
+ prompt_embeds = prompt_embeds.pooler_output
336
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
337
+
338
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
339
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
340
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
341
+
342
+ return prompt_embeds
343
+
344
+ def encode_prompt(
345
+ self,
346
+ prompt: Union[str, List[str]],
347
+ prompt_2: Union[str, List[str]],
348
+ device: Optional[torch.device] = None,
349
+ num_images_per_prompt: int = 1,
350
+ prompt_embeds: Optional[torch.FloatTensor] = None,
351
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
352
+ max_sequence_length: int = 512,
353
+ lora_scale: Optional[float] = None,
354
+ ):
355
+ r"""
356
+
357
+ Args:
358
+ prompt (`str` or `List[str]`, *optional*):
359
+ prompt to be encoded
360
+ prompt_2 (`str` or `List[str]`, *optional*):
361
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
362
+ used in all text-encoders
363
+ device: (`torch.device`):
364
+ torch device
365
+ num_images_per_prompt (`int`):
366
+ number of images that should be generated per prompt
367
+ prompt_embeds (`torch.FloatTensor`, *optional*):
368
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
369
+ provided, text embeddings will be generated from `prompt` input argument.
370
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
371
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
372
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
373
+ clip_skip (`int`, *optional*):
374
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
375
+ the output of the pre-final layer will be used for computing the prompt embeddings.
376
+ lora_scale (`float`, *optional*):
377
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
378
+ """
379
+ device = device or self._execution_device
380
+
381
+ # set lora scale so that monkey patched LoRA
382
+ # function of text encoder can correctly access it
383
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
384
+ self._lora_scale = lora_scale
385
+
386
+ # dynamically adjust the LoRA scale
387
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
388
+ scale_lora_layers(self.text_encoder, lora_scale)
389
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
390
+ scale_lora_layers(self.text_encoder_2, lora_scale)
391
+
392
+ prompt = [prompt] if isinstance(prompt, str) else prompt
393
+
394
+ if prompt_embeds is None:
395
+ prompt_2 = prompt_2 or prompt
396
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
397
+
398
+ # We only use the pooled prompt output from the CLIPTextModel
399
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
400
+ prompt=prompt,
401
+ device=device,
402
+ num_images_per_prompt=num_images_per_prompt,
403
+ )
404
+ prompt_embeds = self._get_t5_prompt_embeds(
405
+ prompt=prompt_2,
406
+ num_images_per_prompt=num_images_per_prompt,
407
+ max_sequence_length=max_sequence_length,
408
+ device=device,
409
+ )
410
+
411
+ if self.text_encoder is not None:
412
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
413
+ # Retrieve the original scale by scaling back the LoRA layers
414
+ unscale_lora_layers(self.text_encoder, lora_scale)
415
+
416
+ if self.text_encoder_2 is not None:
417
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
418
+ # Retrieve the original scale by scaling back the LoRA layers
419
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
420
+
421
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
422
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
423
+
424
+ return prompt_embeds, pooled_prompt_embeds, text_ids
425
+
426
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
427
+ def encode_image(self, image, device, num_images_per_prompt):
428
+ dtype = next(self.image_encoder.parameters()).dtype
429
+
430
+ if not isinstance(image, torch.Tensor):
431
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
432
+
433
+ image = image.to(device=device, dtype=dtype)
434
+ image_embeds = self.image_encoder(image).image_embeds
435
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
436
+ return image_embeds
437
+
438
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
439
+ def prepare_ip_adapter_image_embeds(
440
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
441
+ ):
442
+ image_embeds = []
443
+ if ip_adapter_image_embeds is None:
444
+ if not isinstance(ip_adapter_image, list):
445
+ ip_adapter_image = [ip_adapter_image]
446
+
447
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
448
+ raise ValueError(
449
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
450
+ )
451
+
452
+ for single_ip_adapter_image in ip_adapter_image:
453
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
454
+ image_embeds.append(single_image_embeds[None, :])
455
+ else:
456
+ if not isinstance(ip_adapter_image_embeds, list):
457
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
458
+
459
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
460
+ raise ValueError(
461
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
462
+ )
463
+
464
+ for single_image_embeds in ip_adapter_image_embeds:
465
+ image_embeds.append(single_image_embeds)
466
+
467
+ ip_adapter_image_embeds = []
468
+ for single_image_embeds in image_embeds:
469
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
470
+ single_image_embeds = single_image_embeds.to(device=device)
471
+ ip_adapter_image_embeds.append(single_image_embeds)
472
+
473
+ return ip_adapter_image_embeds
474
+
475
+ def check_inputs(
476
+ self,
477
+ prompt,
478
+ prompt_2,
479
+ height,
480
+ width,
481
+ negative_prompt=None,
482
+ negative_prompt_2=None,
483
+ prompt_embeds=None,
484
+ negative_prompt_embeds=None,
485
+ pooled_prompt_embeds=None,
486
+ negative_pooled_prompt_embeds=None,
487
+ callback_on_step_end_tensor_inputs=None,
488
+ max_sequence_length=None,
489
+ ):
490
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
491
+ logger.warning(
492
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
493
+ )
494
+
495
+ if callback_on_step_end_tensor_inputs is not None and not all(
496
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
497
+ ):
498
+ raise ValueError(
499
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
500
+ )
501
+
502
+ if prompt is not None and prompt_embeds is not None:
503
+ raise ValueError(
504
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
505
+ " only forward one of the two."
506
+ )
507
+ elif prompt_2 is not None and prompt_embeds is not None:
508
+ raise ValueError(
509
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
510
+ " only forward one of the two."
511
+ )
512
+ elif prompt is None and prompt_embeds is None:
513
+ raise ValueError(
514
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
515
+ )
516
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
517
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
518
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
519
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
520
+
521
+ if negative_prompt is not None and negative_prompt_embeds is not None:
522
+ raise ValueError(
523
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
524
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
525
+ )
526
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
527
+ raise ValueError(
528
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
529
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
530
+ )
531
+
532
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
533
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
534
+ raise ValueError(
535
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
536
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
537
+ f" {negative_prompt_embeds.shape}."
538
+ )
539
+
540
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
541
+ raise ValueError(
542
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
543
+ )
544
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
545
+ raise ValueError(
546
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
547
+ )
548
+
549
+ if max_sequence_length is not None and max_sequence_length > 512:
550
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
551
+
552
+ @staticmethod
553
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
554
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
555
+ latent_image_ids = torch.zeros(height, width, 3)
556
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
557
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
558
+
559
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
560
+
561
+ latent_image_ids = latent_image_ids.reshape(
562
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
563
+ )
564
+
565
+ return latent_image_ids.to(device=device, dtype=dtype)
566
+
567
+ @staticmethod
568
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
569
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
570
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
571
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
572
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
573
+
574
+ return latents
575
+
576
+ @staticmethod
577
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
578
+ def _unpack_latents(latents, height, width, vae_scale_factor):
579
+ batch_size, num_patches, channels = latents.shape
580
+
581
+ # VAE applies 8x compression on images but we must also account for packing which requires
582
+ # latent height and width to be divisible by 2.
583
+ height = 2 * (int(height) // (vae_scale_factor * 2))
584
+ width = 2 * (int(width) // (vae_scale_factor * 2))
585
+
586
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
587
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
588
+
589
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
590
+
591
+ return latents
592
+
593
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
594
+ def prepare_latents(
595
+ self,
596
+ batch_size,
597
+ num_channels_latents,
598
+ height,
599
+ width,
600
+ dtype,
601
+ device,
602
+ generator,
603
+ latents=None,
604
+ ):
605
+ # VAE applies 8x compression on images but we must also account for packing which requires
606
+ # latent height and width to be divisible by 2.
607
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
608
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
609
+
610
+ shape = (batch_size, num_channels_latents, height, width)
611
+
612
+ if latents is not None:
613
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
614
+ return latents.to(device=device, dtype=dtype), latent_image_ids
615
+
616
+ if isinstance(generator, list) and len(generator) != batch_size:
617
+ raise ValueError(
618
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
619
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
620
+ )
621
+
622
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
623
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
624
+
625
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
626
+
627
+ return latents, latent_image_ids
628
+
629
+ # Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
630
+ def prepare_image(
631
+ self,
632
+ image,
633
+ width,
634
+ height,
635
+ batch_size,
636
+ num_images_per_prompt,
637
+ device,
638
+ dtype,
639
+ do_classifier_free_guidance=False,
640
+ guess_mode=False,
641
+ ):
642
+ if isinstance(image, torch.Tensor):
643
+ pass
644
+ else:
645
+ image = self.image_processor.preprocess(image, height=height, width=width)
646
+
647
+ image_batch_size = image.shape[0]
648
+
649
+ if image_batch_size == 1:
650
+ repeat_by = batch_size
651
+ else:
652
+ # image batch size is the same as prompt batch size
653
+ repeat_by = num_images_per_prompt
654
+
655
+ image = image.repeat_interleave(repeat_by, dim=0)
656
+
657
+ image = image.to(device=device, dtype=dtype)
658
+
659
+ if do_classifier_free_guidance and not guess_mode:
660
+ image = torch.cat([image] * 2)
661
+
662
+ return image
663
+
664
+ @property
665
+ def guidance_scale(self):
666
+ return self._guidance_scale
667
+
668
+ @property
669
+ def joint_attention_kwargs(self):
670
+ return self._joint_attention_kwargs
671
+
672
+ @property
673
+ def num_timesteps(self):
674
+ return self._num_timesteps
675
+
676
+ @property
677
+ def interrupt(self):
678
+ return self._interrupt
679
+
680
+ @torch.no_grad()
681
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
682
+ def __call__(
683
+ self,
684
+ prompt: Union[str, List[str]] = None,
685
+ prompt_2: Optional[Union[str, List[str]]] = None,
686
+ negative_prompt: Union[str, List[str]] = None,
687
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
688
+ true_cfg_scale: float = 1.0,
689
+ height: Optional[int] = None,
690
+ width: Optional[int] = None,
691
+ num_inference_steps: int = 28,
692
+ sigmas: Optional[List[float]] = None,
693
+ guidance_scale: float = 7.0,
694
+ control_guidance_start: Union[float, List[float]] = 0.0,
695
+ control_guidance_end: Union[float, List[float]] = 1.0,
696
+ control_image: PipelineImageInput = None,
697
+ control_mode: Optional[Union[int, List[int]]] = None,
698
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
699
+ num_images_per_prompt: Optional[int] = 1,
700
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
701
+ latents: Optional[torch.FloatTensor] = None,
702
+ prompt_embeds: Optional[torch.FloatTensor] = None,
703
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
704
+ ip_adapter_image: Optional[PipelineImageInput] = None,
705
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
706
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
707
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
708
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
709
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
710
+ output_type: Optional[str] = "pil",
711
+ return_dict: bool = True,
712
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
713
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
714
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
715
+ max_sequence_length: int = 512,
716
+ ):
717
+ r"""
718
+ Function invoked when calling the pipeline for generation.
719
+
720
+ Args:
721
+ prompt (`str` or `List[str]`, *optional*):
722
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
723
+ instead.
724
+ prompt_2 (`str` or `List[str]`, *optional*):
725
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
726
+ will be used instead
727
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
728
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
729
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
730
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
731
+ num_inference_steps (`int`, *optional*, defaults to 28):
732
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
733
+ expense of slower inference.
734
+ sigmas (`List[float]`, *optional*):
735
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
736
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
737
+ will be used.
738
+ guidance_scale (`float`, *optional*, defaults to 7.0):
739
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
740
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
741
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
742
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
743
+ usually at the expense of lower image quality.
744
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
745
+ The percentage of total steps at which the ControlNet starts applying.
746
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
747
+ The percentage of total steps at which the ControlNet stops applying.
748
+ control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
749
+ `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
750
+ The ControlNet input condition to provide guidance to the `transformer` for generation. If the type is
751
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
752
+ as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
753
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
754
+ images must be passed as a list such that each element of the list can be correctly batched for input
755
+ to a single ControlNet.
756
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
757
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
758
+ to the residual in the original `transformer`. If multiple ControlNets are specified in `init`, you can set
759
+ the corresponding scale as a list.
760
+ control_mode (`int` or `List[int]`, *optional*, defaults to None):
761
+ The control mode when applying ControlNet-Union.
762
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
763
+ The number of images to generate per prompt.
764
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
765
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
766
+ to make generation deterministic.
767
+ latents (`torch.FloatTensor`, *optional*):
768
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
769
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
770
+ tensor will be generated by sampling using the supplied random `generator`.
771
+ prompt_embeds (`torch.FloatTensor`, *optional*):
772
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
773
+ provided, text embeddings will be generated from `prompt` input argument.
774
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
775
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
776
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
777
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
778
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
779
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
780
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
781
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
782
+ negative_ip_adapter_image (`PipelineImageInput`, *optional*):
783
+ Optional image input to work with IP Adapters.
784
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
785
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
786
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
787
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
788
+ output_type (`str`, *optional*, defaults to `"pil"`):
789
+ The output format of the generated image. Choose between
790
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
791
+ return_dict (`bool`, *optional*, defaults to `True`):
792
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
793
+ joint_attention_kwargs (`dict`, *optional*):
794
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
795
+ `self.processor` in
796
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
797
+ callback_on_step_end (`Callable`, *optional*):
798
+ A function that is called at the end of each denoising step during inference. The function is called
799
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
800
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
801
+ `callback_on_step_end_tensor_inputs`.
802
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
803
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
804
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
805
+ `._callback_tensor_inputs` attribute of your pipeline class.
806
+ max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
807
+
808
+ Examples:
809
+
810
+ Returns:
811
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
812
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
813
+ images.
814
+ """
815
+
816
+ height = height or self.default_sample_size * self.vae_scale_factor
817
+ width = width or self.default_sample_size * self.vae_scale_factor
818
+
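+ # normalize `control_guidance_start` / `control_guidance_end` so there is one (start, end) pair per controlnet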
819
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
820
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
821
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
822
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
823
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
824
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
825
+ control_guidance_start, control_guidance_end = (
826
+ mult * [control_guidance_start],
827
+ mult * [control_guidance_end],
828
+ )
829
+
830
+ # 1. Check inputs. Raise error if not correct
831
+ self.check_inputs(
832
+ prompt,
833
+ prompt_2,
834
+ height,
835
+ width,
836
+ negative_prompt=negative_prompt,
837
+ negative_prompt_2=negative_prompt_2,
838
+ prompt_embeds=prompt_embeds,
839
+ negative_prompt_embeds=negative_prompt_embeds,
840
+ pooled_prompt_embeds=pooled_prompt_embeds,
841
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
842
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
843
+ max_sequence_length=max_sequence_length,
844
+ )
845
+
846
+ self._guidance_scale = guidance_scale
847
+ self._joint_attention_kwargs = joint_attention_kwargs
848
+ self._interrupt = False
849
+
850
+ # 2. Define call parameters
851
+ if prompt is not None and isinstance(prompt, str):
852
+ batch_size = 1
853
+ elif prompt is not None and isinstance(prompt, list):
854
+ batch_size = len(prompt)
855
+ else:
856
+ batch_size = prompt_embeds.shape[0]
857
+
858
+ device = self._execution_device
859
+ dtype = self.transformer.dtype
860
+
861
+ # 3. Prepare text embeddings
862
+ lora_scale = (
863
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
864
+ )
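+ # "true" CFG runs a second transformer pass with the negative prompt at every step; it is only enabled when a negative prompt is given and `true_cfg_scale` > 1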
865
+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
866
+ (
867
+ prompt_embeds,
868
+ pooled_prompt_embeds,
869
+ text_ids,
870
+ ) = self.encode_prompt(
871
+ prompt=prompt,
872
+ prompt_2=prompt_2,
873
+ prompt_embeds=prompt_embeds,
874
+ pooled_prompt_embeds=pooled_prompt_embeds,
875
+ device=device,
876
+ num_images_per_prompt=num_images_per_prompt,
877
+ max_sequence_length=max_sequence_length,
878
+ lora_scale=lora_scale,
879
+ )
880
+ if do_true_cfg:
881
+ (
882
+ negative_prompt_embeds,
883
+ negative_pooled_prompt_embeds,
884
+ _,
885
+ ) = self.encode_prompt(
886
+ prompt=negative_prompt,
887
+ prompt_2=negative_prompt_2,
888
+ prompt_embeds=negative_prompt_embeds,
889
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
890
+ device=device,
891
+ num_images_per_prompt=num_images_per_prompt,
892
+ max_sequence_length=max_sequence_length,
893
+ lora_scale=lora_scale,
894
+ )
895
+
896
+ # 4. Prepare control image
897
+ num_channels_latents = self.transformer.config.in_channels // 4
898
+ if isinstance(self.controlnet, FluxControlNetModel):
899
+ control_image = self.prepare_image(
900
+ image=control_image,
901
+ width=width,
902
+ height=height,
903
+ batch_size=batch_size * num_images_per_prompt,
904
+ num_images_per_prompt=num_images_per_prompt,
905
+ device=device,
906
+ dtype=self.vae.dtype,
907
+ )
908
+ height, width = control_image.shape[-2:]
909
+
910
+ # XLabs controlnets have an input_hint_block; InstantX controlnets do not
911
+ controlnet_blocks_repeat = self.controlnet.input_hint_block is not None
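+ # InstantX-style controlnets (no input_hint_block) expect VAE-encoded, packed latents below; XLabs-style controlnets consume the preprocessed conditioning pixels directly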
912
+ if self.controlnet.input_hint_block is None:
913
+ # vae encode
914
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
915
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
916
+
917
+ # pack
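+ # _pack_latents folds each 2x2 patch of VAE latents into a single token so the control latents match the transformer's sequence layout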
918
+ height_control_image, width_control_image = control_image.shape[2:]
919
+ control_image = self._pack_latents(
920
+ control_image,
921
+ batch_size * num_images_per_prompt,
922
+ num_channels_latents,
923
+ height_control_image,
924
+ width_control_image,
925
+ )
926
+
927
+ # Expand `control_mode` so there is one mode id per entry in the control image batch.
928
+ if control_mode is not None:
929
+ if not isinstance(control_mode, int):
930
+ raise ValueError(" For `FluxControlNet`, `control_mode` should be an `int` or `None`")
931
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
932
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
933
+
934
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
935
+ control_images = []
936
+ # XLabs controlnets have an input_hint_block; InstantX controlnets do not
937
+ controlnet_blocks_repeat = self.controlnet.nets[0].input_hint_block is not None
938
+ for i, control_image_ in enumerate(control_image):
939
+ control_image_ = self.prepare_image(
940
+ image=control_image_,
941
+ width=width,
942
+ height=height,
943
+ batch_size=batch_size * num_images_per_prompt,
944
+ num_images_per_prompt=num_images_per_prompt,
945
+ device=device,
946
+ dtype=self.vae.dtype,
947
+ )
948
+ height, width = control_image_.shape[-2:]
949
+
950
+ if self.controlnet.nets[0].input_hint_block is None:
951
+ # vae encode
952
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
953
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
954
+
955
+ # pack
956
+ height_control_image, width_control_image = control_image_.shape[2:]
957
+ control_image_ = self._pack_latents(
958
+ control_image_,
959
+ batch_size * num_images_per_prompt,
960
+ num_channels_latents,
961
+ height_control_image,
962
+ width_control_image,
963
+ )
964
+ control_images.append(control_image_)
965
+
966
+ control_image = control_images
967
+
968
+ # Here we ensure that `control_mode` has the same length as the control_image.
969
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
970
+ raise ValueError(
971
+ "For Multi-ControlNet, `control_mode` must be a list of the same "
972
+ + " length as the number of controlnets (control images) specified"
973
+ )
974
+ if not isinstance(control_mode, list):
975
+ control_mode = [control_mode] * len(control_image)
976
+ # set control mode
977
+ control_modes = []
978
+ for cmode in control_mode:
979
+ if cmode is None:
980
+ cmode = -1
981
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
982
+ control_modes.append(control_mode)
983
+ control_mode = control_modes
984
+
985
+ # 5. Prepare latent variables
986
+ num_channels_latents = self.transformer.config.in_channels // 4
987
+ latents, latent_image_ids = self.prepare_latents(
988
+ batch_size * num_images_per_prompt,
989
+ num_channels_latents,
990
+ height,
991
+ width,
992
+ prompt_embeds.dtype,
993
+ device,
994
+ generator,
995
+ latents,
996
+ )
997
+
998
+ # 6. Prepare timesteps
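+ # flow-matching schedule: sigmas run from 1.0 down to 1/num_inference_steps, and `mu` shifts the schedule according to the number of packed image tokens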
999
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1000
+ image_seq_len = latents.shape[1]
1001
+ mu = calculate_shift(
1002
+ image_seq_len,
1003
+ self.scheduler.config.get("base_image_seq_len", 256),
1004
+ self.scheduler.config.get("max_image_seq_len", 4096),
1005
+ self.scheduler.config.get("base_shift", 0.5),
1006
+ self.scheduler.config.get("max_shift", 1.15),
1007
+ )
1008
+ timesteps, num_inference_steps = retrieve_timesteps(
1009
+ self.scheduler,
1010
+ num_inference_steps,
1011
+ device,
1012
+ sigmas=sigmas,
1013
+ mu=mu,
1014
+ )
1015
+
1016
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1017
+ self._num_timesteps = len(timesteps)
1018
+
1019
+ # 7. Create the per-step list stating which controlnets to keep
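+ # each entry is 1.0 while the current step fraction lies inside [control_guidance_start, control_guidance_end] and 0.0 otherwise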
1020
+ controlnet_keep = []
1021
+ for i in range(len(timesteps)):
1022
+ keeps = [
1023
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1024
+ for s, e in zip(control_guidance_start, control_guidance_end)
1025
+ ]
1026
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
1027
+
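+ # if only one side of the IP-adapter inputs is provided, substitute an all-zero placeholder image for the other side so both the positive and negative passes receive image embeds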
1028
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
1029
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
1030
+ ):
1031
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1032
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
1033
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
1034
+ ):
1035
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
1036
+
1037
+ if self.joint_attention_kwargs is None:
1038
+ self._joint_attention_kwargs = {}
1039
+
1040
+ image_embeds = None
1041
+ negative_image_embeds = None
1042
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1043
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1044
+ ip_adapter_image,
1045
+ ip_adapter_image_embeds,
1046
+ device,
1047
+ batch_size * num_images_per_prompt,
1048
+ )
1049
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
1050
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
1051
+ negative_ip_adapter_image,
1052
+ negative_ip_adapter_image_embeds,
1053
+ device,
1054
+ batch_size * num_images_per_prompt,
1055
+ )
1056
+
1057
+ # 8. Denoising loop
1058
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1059
+ for i, t in enumerate(timesteps):
1060
+ if self.interrupt:
1061
+ continue
1062
+
1063
+ if image_embeds is not None:
1064
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
1065
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1066
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1067
+
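+ # checkpoints trained with guidance embeddings (config.guidance_embeds=True) receive guidance_scale as an embedding; otherwise no guidance tensor is passed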
1068
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
1069
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
1070
+ else:
1071
+ use_guidance = self.controlnet.config.guidance_embeds
1072
+
1073
+ guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
1074
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
1075
+
1076
+ if isinstance(controlnet_keep[i], list):
1077
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1078
+ else:
1079
+ controlnet_cond_scale = controlnet_conditioning_scale
1080
+ if isinstance(controlnet_cond_scale, list):
1081
+ controlnet_cond_scale = controlnet_cond_scale[0]
1082
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1083
+
1084
+ # controlnet
1085
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
1086
+ hidden_states=latents,
1087
+ controlnet_cond=control_image,
1088
+ controlnet_mode=control_mode,
1089
+ conditioning_scale=cond_scale,
1090
+ timestep=timestep / 1000,
1091
+ guidance=guidance,
1092
+ pooled_projections=pooled_prompt_embeds,
1093
+ encoder_hidden_states=prompt_embeds,
1094
+ txt_ids=text_ids,
1095
+ img_ids=latent_image_ids,
1096
+ joint_attention_kwargs=self.joint_attention_kwargs,
1097
+ return_dict=False,
1098
+ )
1099
+
1100
+ guidance = (
1101
+ torch.tensor([guidance_scale], device=device) if self.transformer.config.guidance_embeds else None
1102
+ )
1103
+ guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
1104
+
1105
+ noise_pred = self.transformer(
1106
+ hidden_states=latents,
1107
+ timestep=timestep / 1000,
1108
+ guidance=guidance,
1109
+ pooled_projections=pooled_prompt_embeds,
1110
+ encoder_hidden_states=prompt_embeds,
1111
+ controlnet_block_samples=controlnet_block_samples,
1112
+ controlnet_single_block_samples=controlnet_single_block_samples,
1113
+ txt_ids=text_ids,
1114
+ img_ids=latent_image_ids,
1115
+ joint_attention_kwargs=self.joint_attention_kwargs,
1116
+ return_dict=False,
1117
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
1118
+ )[0]
1119
+
1120
+ if do_true_cfg:
1121
+ if negative_image_embeds is not None:
1122
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
1123
+ neg_noise_pred = self.transformer(
1124
+ hidden_states=latents,
1125
+ timestep=timestep / 1000,
1126
+ guidance=guidance,
1127
+ pooled_projections=negative_pooled_prompt_embeds,
1128
+ encoder_hidden_states=negative_prompt_embeds,
1129
+ controlnet_block_samples=controlnet_block_samples,
1130
+ controlnet_single_block_samples=controlnet_single_block_samples,
1131
+ txt_ids=text_ids,
1132
+ img_ids=latent_image_ids,
1133
+ joint_attention_kwargs=self.joint_attention_kwargs,
1134
+ return_dict=False,
1135
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
1136
+ )[0]
1137
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
1138
+
1139
+ # compute the previous noisy sample x_t -> x_t-1
1140
+ latents_dtype = latents.dtype
1141
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1142
+
1143
+ if latents.dtype != latents_dtype:
1144
+ if torch.backends.mps.is_available():
1145
+ # some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
1146
+ latents = latents.to(latents_dtype)
1147
+
1148
+ if callback_on_step_end is not None:
1149
+ callback_kwargs = {}
1150
+ for k in callback_on_step_end_tensor_inputs:
1151
+ callback_kwargs[k] = locals()[k]
1152
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1153
+
1154
+ latents = callback_outputs.pop("latents", latents)
1155
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1156
+ control_image = callback_outputs.pop("control_image", control_image)
1157
+
1158
+ # update the progress bar
1159
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1160
+ progress_bar.update()
1161
+
1162
+ if XLA_AVAILABLE:
1163
+ xm.mark_step()
1164
+
1165
+ if output_type == "latent":
1166
+ image = latents
1167
+
1168
+ else:
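+ # undo the 2x2 patch packing, reverse the VAE scaling/shift, then decode back to pixel space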
1169
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
1170
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1171
+
1172
+ image = self.vae.decode(latents, return_dict=False)[0]
1173
+ image = self.image_processor.postprocess(image, output_type=output_type)
1174
+
1175
+ # Offload all models
1176
+ self.maybe_free_model_hooks()
1177
+
1178
+ if not return_dict:
1179
+ return (image,)
1180
+
1181
+ return FluxPipelineOutput(images=image)
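+
+ # --- Usage sketch (illustrative only) ---
+ # A minimal, hedged example of driving this pipeline. The class name
+ # `FluxControlNetPipeline`, the checkpoint ids and the control image URL below are
+ # assumptions for illustration, not pinned down by this file; swap in the modules
+ # and checkpoints you actually use.
+ #
+ #     import torch
+ #     from diffusers.utils import load_image
+ #     from controlnet_flux import FluxControlNetModel
+ #     from pipeline_flux_controlnet import FluxControlNetPipeline
+ #
+ #     controlnet = FluxControlNetModel.from_pretrained(
+ #         "InstantX/FLUX.1-dev-Controlnet-Union", torch_dtype=torch.bfloat16
+ #     )
+ #     pipe = FluxControlNetPipeline.from_pretrained(
+ #         "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
+ #     ).to("cuda")
+ #
+ #     control_image = load_image("https://example.com/canny.png")  # hypothetical URL
+ #     image = pipe(
+ #         prompt="a photo of a cat on a windowsill",
+ #         control_image=control_image,
+ #         control_mode=0,                       # mode id depends on the union controlnet's training
+ #         controlnet_conditioning_scale=0.7,
+ #         num_inference_steps=28,
+ #         guidance_scale=3.5,
+ #     ).images[0]
+ #     image.save("flux_controlnet_out.png")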