Diffusers
TalHach61 committed · verified
Commit 171612d · 1 Parent(s): d132afb

Create controlnet_bria.py

Files changed (1):
  1. controlnet_bria.py +539 -0
controlnet_bria.py ADDED
@@ -0,0 +1,539 @@
# type: ignore
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformer_bria import TimestepProjEmbeddings
from diffusers.models.controlnet import zero_module
from diffusers.utils.outputs import BaseOutput
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.modeling_outputs import Transformer2DModelOutput

from transformer_bria import FluxSingleTransformerBlock, FluxTransformerBlock, EmbedND
from diffusers.models.attention_processor import AttentionProcessor

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class BriaControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]
    controlnet_single_block_samples: Tuple[torch.Tensor]


class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: List[int] = [16, 56, 56],
        num_mode: int = None,
        rope_theta: int = 10000,
        time_theta: int = 10000,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        # self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)

        # text_time_guidance_cls = (
        #     CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        # )
        # self.time_text_embed = text_time_guidance_cls(
        #     embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        # )
        self.time_embed = TimestepProjEmbeddings(
            embedding_dim=self.inner_dim, time_theta=time_theta
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for i in range(num_single_layers)
            ]
        )

        # controlnet_blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.union = num_mode is not None and num_mode > 0
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

        self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self):
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 4,
        num_single_layers: int = 10,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        load_weights_from_transformer=True,
    ):
        config = transformer.config
        config["num_layers"] = num_layers
        config["num_single_layers"] = num_single_layers
        config["attention_head_dim"] = attention_head_dim
        config["num_attention_heads"] = num_attention_heads

        controlnet = cls(**config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            # This model registers its timestep embedding as `time_embed` (see __init__), not `time_text_embed`.
            controlnet.time_embed.load_state_dict(transformer.time_embed.state_dict())
            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
            controlnet.single_transformer_blocks.load_state_dict(
                transformer.single_transformer_blocks.state_dict(), strict=False
            )

            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

        return controlnet

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        controlnet_mode: torch.Tensor = None,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[BriaControlNetOutput, Tuple]:
        """
        The [`BriaControlNetModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, in_channels)`):
                Input `hidden_states` (packed latents).
            controlnet_cond (`torch.Tensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            controlnet_mode (`torch.Tensor`):
                The mode tensor of shape `(batch_size, 1)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                from the embeddings of input conditions.
            timestep (`torch.LongTensor`):
                Used to indicate denoising step.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`BriaControlNetOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`BriaControlNetOutput`] is returned, otherwise a `tuple` containing the
            ControlNet block samples.
        """
        if guidance is not None:
            print("guidance is not supported in BriaControlNetModel")
        if pooled_projections is not None:
            print("pooled_projections is not supported in BriaControlNetModel")
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)

        # Convert controlnet_cond to the same dtype as the model weights
        controlnet_cond = controlnet_cond.to(dtype=self.controlnet_x_embedder.weight.dtype)

        # add
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

        timestep = timestep.to(hidden_states.dtype)  # Original code was * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype)  # Original code was * 1000
        else:
            guidance = None

        temb = self.time_embed(timestep, dtype=hidden_states.dtype)

        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        if self.union:
            # union mode
            if controlnet_mode is None:
                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")

            # Validate controlnet_mode values are within the valid range
            if torch.any(controlnet_mode < 0) or torch.any(controlnet_mode >= self.num_mode):
                raise ValueError(f"`controlnet_mode` values must be in range [0, {self.num_mode-1}], but got values outside this range")

            # union mode emb
            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
            if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]:  # duplicate mode emb for each batch
                controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, encoder_hidden_states.shape[2])
            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)

            txt_ids = torch.cat((txt_ids[0:1, :], txt_ids), dim=0)
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            block_samples = block_samples + (hidden_states,)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )
            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)

        # controlnet block
        controlnet_block_samples = ()
        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
            block_sample = controlnet_block(block_sample)
            controlnet_block_samples = controlnet_block_samples + (block_sample,)

        controlnet_single_block_samples = ()
        for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
            single_block_sample = controlnet_block(single_block_sample)
            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)

        # scaling
        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]

        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
        controlnet_single_block_samples = (
            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
        )

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_block_samples, controlnet_single_block_samples)

        return BriaControlNetOutput(
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
        )


class BriaMultiControlNetModel(ModelMixin):
    r"""
    `BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel

    This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
    compatible with `BriaControlNetModel`.

    Args:
        controlnets (`List[BriaControlNetModel]`):
            Provides additional conditioning to the transformer during the denoising process. You must set multiple
            `BriaControlNetModel` as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        controlnet_cond: List[torch.Tensor],
        controlnet_mode: List[torch.Tensor],
        conditioning_scale: List[float],
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[BriaControlNetOutput, Tuple]:
        # ControlNet-Union with multiple conditions
        # only load one ControlNet to save memory
        if len(self.nets) == 1 and self.nets[0].union:
            controlnet = self.nets[0]

            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    control_block_samples = [
                        control_block_sample + block_sample
                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                    ]

                    control_single_block_samples = [
                        control_single_block_sample + block_sample
                        for control_single_block_sample, block_sample in zip(
                            control_single_block_samples, single_block_samples
                        )
                    ]

        # Regular Multi-ControlNets
        # load all ControlNets into memory
        else:
            for i, (image, mode, scale, controlnet) in enumerate(
                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
            ):
                block_samples, single_block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    controlnet_mode=mode[:, None],
                    conditioning_scale=scale,
                    timestep=timestep,
                    guidance=guidance,
                    pooled_projections=pooled_projections,
                    encoder_hidden_states=encoder_hidden_states,
                    txt_ids=txt_ids,
                    img_ids=img_ids,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                    control_single_block_samples = single_block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
                    if single_block_samples is not None and control_single_block_samples is not None:
                        control_single_block_samples = [
                            control_single_block_sample + block_sample
                            for control_single_block_sample, block_sample in zip(
                                control_single_block_samples, single_block_samples
                            )
                        ]

        return control_block_samples, control_single_block_samples
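
A minimal construction sketch follows, for orientation only; it is not part of the commit. It assumes transformer_bria.py from the same repository is importable (the module above imports TimestepProjEmbeddings, FluxTransformerBlock, FluxSingleTransformerBlock, and EmbedND from it), and the shrunken config values are illustrative stand-ins for the real defaults (num_layers=19, num_single_layers=38, attention_head_dim=128, num_attention_heads=24, joint_attention_dim=4096).

# Sketch only (not part of this commit). Assumes controlnet_bria.py and
# transformer_bria.py from this repository are on the Python path.
from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel

controlnet = BriaControlNetModel(
    in_channels=64,              # packed-latent channel count, as in the defaults above
    num_layers=2,                # shrunk from 19 for a quick local check
    num_single_layers=2,         # shrunk from 38
    attention_head_dim=32,
    num_attention_heads=2,
    joint_attention_dim=64,      # text-encoder hidden size fed to context_embedder
    axes_dims_rope=[8, 12, 12],  # per-axis rope dims, chosen to sum to attention_head_dim
                                 # (mirroring the 16 + 56 + 56 = 128 default)
    num_mode=None,               # pass the number of control modes to enable union mode
)

# One or more ControlNets can be wrapped for multi-condition use:
multi_controlnet = BriaMultiControlNetModel([controlnet])

# Expected forward inputs, from the code above: packed latents `hidden_states` and
# `controlnet_cond` of shape (batch, seq_len, in_channels), `encoder_hidden_states`
# of shape (batch, text_seq_len, joint_attention_dim), a `timestep` tensor, and 2D
# `img_ids` / `txt_ids` position-id tensors; the model returns per-block residuals
# scaled by `conditioning_scale`.
print(sum(p.numel() for p in controlnet.parameters()), "parameters")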