Diffusers
TalHach61 committed
Commit 93db988 · verified · 1 Parent(s): a3ca6b2

Create controlnet_bria.py

Files changed (1)
  1. controlnet_bria.py +532 -0
controlnet_bria.py ADDED
# type: ignore
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformer_bria import TimestepProjEmbeddings, FluxSingleTransformerBlock, FluxTransformerBlock, EmbedND
from diffusers.models.controlnet import zero_module
from diffusers.utils.outputs import BaseOutput
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.attention_processor import AttentionProcessor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class BriaControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]
    controlnet_single_block_samples: Tuple[torch.Tensor]


class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        num_mode: Optional[int] = None,
        rope_theta: int = 10000,
        time_theta: int = 10000,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)

        # text_time_guidance_cls = (
        #     CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        # )
        # self.time_text_embed = text_time_guidance_cls(
        #     embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        # )
        self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_single_layers)
            ]
        )

        # controlnet_blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.union = num_mode is not None and num_mode > 0
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

        self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

        self.gradient_checkpointing = False
    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self):
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.
                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 4,
        num_single_layers: int = 10,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        load_weights_from_transformer=True,
    ):
        # `transformer.config` is a FrozenDict; copy it into a mutable dict before overriding entries.
        config = dict(transformer.config)
        config["num_layers"] = num_layers
        config["num_single_layers"] = num_single_layers
        config["attention_head_dim"] = attention_head_dim
        config["num_attention_heads"] = num_attention_heads

        controlnet = cls(**config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            # This model registers its timestep embedding as `time_embed` (see `__init__`); the base transformer
            # is assumed to use the same attribute name.
            controlnet.time_embed.load_state_dict(transformer.time_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
            controlnet.single_transformer_blocks.load_state_dict(
                transformer.single_transformer_blocks.state_dict(), strict=False
            )

            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

        return controlnet
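    # Minimal usage sketch for `from_transformer`, assuming `base_transformer` is an already-loaded Bria
    # transformer (from transformer_bria) whose `pos_embed`, `time_embed`, `context_embedder`, `x_embedder`
    # and block modules match the attribute names copied above:
    #
    #     controlnet = BriaControlNetModel.from_transformer(
    #         base_transformer,
    #         num_layers=4,
    #         num_single_layers=10,
    #         load_weights_from_transformer=True,
    #     )
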
    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        controlnet_mode: torch.Tensor = None,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`BriaControlNetModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states` (packed latent tokens).
            controlnet_cond (`torch.Tensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            controlnet_mode (`torch.Tensor`):
                The mode tensor of shape `(batch_size, 1)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`BriaControlNetOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is `True`, a [`BriaControlNetOutput`] is returned, otherwise a `tuple` whose elements
            are the ControlNet block samples and single-block samples.
        """
        if guidance is not None:
            logger.warning("guidance is not supported in BriaControlNetModel")
        if pooled_projections is not None:
            logger.warning("pooled_projections is not supported in BriaControlNetModel")
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)

        # Convert controlnet_cond to the same dtype as the model weights
        controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype)

        # add
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

        timestep = timestep.to(hidden_states.dtype)  # Original code was * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype)  # Original code was * 1000
        else:
            guidance = None

        temb = self.time_embed(timestep, dtype=hidden_states.dtype)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        if self.union:
            # union mode
            if controlnet_mode is None:
                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")

            # Validate that controlnet_mode values are within the valid range
            if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.config.num_mode):
                raise ValueError(
                    f"`controlnet_mode` values must be in range [0, {self.config.num_mode - 1}], "
                    "but got values outside this range"
                )

            # union mode emb
            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
            if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]:  # duplicate mode emb for each batch
                controlnet_mode_emb = controlnet_mode_emb.expand(
                    encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2]
                )
            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)

            txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0)

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            block_samples = block_samples + (hidden_states,)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)

        # controlnet block
        controlnet_block_samples = ()
        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
            block_sample = controlnet_block(block_sample)
            controlnet_block_samples = controlnet_block_samples + (block_sample,)

        controlnet_single_block_samples = ()
        for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
            single_block_sample = controlnet_block(single_block_sample)
            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)

        # scaling
        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]

        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
        controlnet_single_block_samples = (
            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_block_samples, controlnet_single_block_samples)

        return BriaControlNetOutput(
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
        )


class BriaMultiControlNetModel(ModelMixin):
    r"""
    `BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel

    This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed
    to be compatible with `BriaControlNetModel`.

    Args:
        controlnets (`List[BriaControlNetModel]`):
            Provides additional conditioning to the transformer during the denoising process. You must set
            multiple `BriaControlNetModel` as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

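    # Minimal usage sketch, assuming `controlnet_a` and `controlnet_b` are loaded `BriaControlNetModel`
    # instances and the remaining tensors come from the surrounding pipeline; the wrapper runs each
    # ControlNet and sums their per-block residuals:
    #
    #     multi_controlnet = BriaMultiControlNetModel([controlnet_a, controlnet_b])
    #     block_samples, single_block_samples = multi_controlnet(
    #         hidden_states=latents,
    #         controlnet_cond=[cond_a, cond_b],
    #         controlnet_mode=[mode_a, mode_b],
    #         conditioning_scale=[1.0, 0.7],
    #         encoder_hidden_states=prompt_embeds,
    #         timestep=timestep,
    #         img_ids=img_ids,
    #         txt_ids=txt_ids,
    #         return_dict=False,
    #     )
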
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        controlnet_cond: List[torch.Tensor],
        controlnet_mode: List[torch.Tensor],
        conditioning_scale: List[float],
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[BriaControlNetOutput, Tuple]:
        # ControlNet-Union with multiple conditions
        # only load one ControlNet to save memory
        if len(self.nets) == 1 and self.nets[0].union:
            controlnet = self.nets[0]

            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    control_block_samples = [
                        control_block_sample + block_sample
                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                    ]

                    control_single_block_samples = [
                        control_single_block_sample + block_sample
                        for control_single_block_sample, block_sample in zip(
                            control_single_block_samples, single_block_samples
                        )
                    ]

        # Regular Multi-ControlNets
        # load all ControlNets into memory
        else:
            for i, (image, mode, scale, controlnet) in enumerate(
                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
            ):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
                    if single_block_samples is not None and control_single_block_samples is not None:
                        control_single_block_samples = [
                            control_single_block_sample + block_sample
                            for control_single_block_sample, block_sample in zip(
                                control_single_block_samples, single_block_samples
                            )
                        ]

        return control_block_samples, control_single_block_samples
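

if __name__ == "__main__":
    # Smoke-test sketch with tiny, illustrative dimensions. It assumes transformer_bria is importable from the
    # same repository and that its blocks follow the diffusers Flux interfaces used in `forward` above; the
    # sizes below are not the shipped Bria configuration.
    model = BriaControlNetModel(
        in_channels=16,
        num_layers=1,
        num_single_layers=1,
        attention_head_dim=32,
        num_attention_heads=2,
        joint_attention_dim=48,
        axes_dims_rope=[8, 12, 12],  # must sum to attention_head_dim
    )
    batch, img_seq, txt_seq = 1, 64, 8
    out = model(
        hidden_states=torch.randn(batch, img_seq, 16),
        controlnet_cond=torch.randn(batch, img_seq, 16),
        encoder_hidden_states=torch.randn(batch, txt_seq, 48),
        timestep=torch.tensor([1.0]),
        img_ids=torch.zeros(img_seq, 3),
        txt_ids=torch.zeros(txt_seq, 3),
        return_dict=True,
    )
    print(len(out.controlnet_block_samples), len(out.controlnet_single_block_samples))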