ABDALLALSWAITI committed
Commit a662214 · verified · 1 Parent(s): a8285f8

Upload FP8 quantized model

Files changed (1)
  1. controlnet_flux.py +509 -0
controlnet_flux.py ADDED
@@ -0,0 +1,509 @@
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin
+from diffusers.models.attention_processor import AttentionProcessor
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
+from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class FluxControlNetOutput(BaseOutput):
+    controlnet_block_samples: Tuple[torch.Tensor]
+    controlnet_single_block_samples: Tuple[torch.Tensor]
+
+
+class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 1,
+        in_channels: int = 64,
+        num_layers: int = 19,
+        num_single_layers: int = 38,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        joint_attention_dim: int = 4096,
+        pooled_projection_dim: int = 768,
+        guidance_embeds: bool = False,
+        axes_dims_rope: List[int] = [16, 56, 56],
+        num_mode: int = None,
+        conditioning_embedding_channels: int = None,
+    ):
+        super().__init__()
+        self.out_channels = in_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+        text_time_guidance_cls = (
+            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
+        )
+        self.time_text_embed = text_time_guidance_cls(
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+        )
+
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                FluxTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for i in range(num_layers)
+            ]
+        )
+
+        self.single_transformer_blocks = nn.ModuleList(
+            [
+                FluxSingleTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                )
+                for i in range(num_single_layers)
+            ]
+        )
+
+        # controlnet_blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(len(self.transformer_blocks)):
+            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+        self.controlnet_single_blocks = nn.ModuleList([])
+        for _ in range(len(self.single_transformer_blocks)):
+            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
+
+        self.union = num_mode is not None
+        if self.union:
+            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
+
+        if conditioning_embedding_channels is not None:
+            self.input_hint_block = ControlNetConditioningEmbedding(
+                conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
+            )
+            self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
+        else:
+            self.input_hint_block = None
+            self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self):
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    @classmethod
+    def from_transformer(
+        cls,
+        transformer,
+        num_layers: int = 4,
+        num_single_layers: int = 10,
+        attention_head_dim: int = 128,
+        num_attention_heads: int = 24,
+        load_weights_from_transformer=True,
+    ):
+        config = dict(transformer.config)
+        config["num_layers"] = num_layers
+        config["num_single_layers"] = num_single_layers
+        config["attention_head_dim"] = attention_head_dim
+        config["num_attention_heads"] = num_attention_heads
+
+        controlnet = cls.from_config(config)
+
+        if load_weights_from_transformer:
+            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
+            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
+            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
+            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
+            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
+            controlnet.single_transformer_blocks.load_state_dict(
+                transformer.single_transformer_blocks.state_dict(), strict=False
+            )
+
+            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
+
+        return controlnet
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        controlnet_cond: torch.Tensor,
+        controlnet_mode: torch.Tensor = None,
+        conditioning_scale: float = 1.0,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        guidance: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        """
+        The [`FluxTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            controlnet_cond (`torch.Tensor`):
+                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+            controlnet_mode (`torch.Tensor`):
+                The mode tensor of shape `(batch_size, 1)`.
+            conditioning_scale (`float`, defaults to `1.0`):
+                The scale factor for ControlNet outputs.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                from the embeddings of input conditions.
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            block_controlnet_hidden_states (`list` of `torch.Tensor`):
+                A list of tensors that if specified are added to the residuals of transformer blocks.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+        hidden_states = self.x_embedder(hidden_states)
+
+        if self.input_hint_block is not None:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            batch_size, channels, height_pw, width_pw = controlnet_cond.shape
+            height = height_pw // self.config.patch_size
+            width = width_pw // self.config.patch_size
+            controlnet_cond = controlnet_cond.reshape(
+                batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
+            )
+            controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
+            controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
+        # add
+        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
+
+        timestep = timestep.to(hidden_states.dtype) * 1000
+        if guidance is not None:
+            guidance = guidance.to(hidden_states.dtype) * 1000
+        else:
+            guidance = None
+        temb = (
+            self.time_text_embed(timestep, pooled_projections)
+            if guidance is None
+            else self.time_text_embed(timestep, guidance, pooled_projections)
+        )
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            img_ids = img_ids[0]
+
+        if self.union:
+            # union mode
+            if controlnet_mode is None:
+                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
+            # union mode emb
+            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
+            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
+            txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
+        image_rotary_emb = self.pos_embed(ids)
+
+        block_samples = ()
+        for index_block, block in enumerate(self.transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+            block_samples = block_samples + (hidden_states,)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        single_block_samples = ()
+        for index_block, block in enumerate(self.single_transformer_blocks):
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    temb,
+                    image_rotary_emb,
+                )
+
+            else:
+                hidden_states = block(
+                    hidden_states=hidden_states,
+                    temb=temb,
+                    image_rotary_emb=image_rotary_emb,
+                )
+            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+
+        # controlnet block
+        controlnet_block_samples = ()
+        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
+            block_sample = controlnet_block(block_sample)
+            controlnet_block_samples = controlnet_block_samples + (block_sample,)
+
+        controlnet_single_block_samples = ()
+        for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
+            single_block_sample = controlnet_block(single_block_sample)
+            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
+
+        # scaling
+        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
+        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
+
+        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
+        controlnet_single_block_samples = (
+            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
+        )
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (controlnet_block_samples, controlnet_single_block_samples)
+
+        return FluxControlNetOutput(
+            controlnet_block_samples=controlnet_block_samples,
+            controlnet_single_block_samples=controlnet_single_block_samples,
+        )
+
+
+class FluxMultiControlNetModel(ModelMixin):
+    r"""
+    `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
+
+    This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
+    compatible with `FluxControlNetModel`.
+
+    Args:
+        controlnets (`List[FluxControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `FluxControlNetModel` as a list.
+    """
+
+    def __init__(self, controlnets):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        controlnet_cond: List[torch.tensor],
+        controlnet_mode: List[torch.tensor],
+        conditioning_scale: List[float],
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
+        timestep: torch.LongTensor = None,
+        img_ids: torch.Tensor = None,
+        txt_ids: torch.Tensor = None,
+        guidance: torch.Tensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[FluxControlNetOutput, Tuple]:
+        # ControlNet-Union with multiple conditions
+        # only load one ControlNet for saving memories
+        if len(self.nets) == 1:
+            controlnet = self.nets[0]
+
+            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
+                block_samples, single_block_samples = controlnet(
+                    hidden_states=hidden_states,
+                    controlnet_cond=image,
+                    controlnet_mode=mode[:, None],
+                    conditioning_scale=scale,
+                    timestep=timestep,
+                    guidance=guidance,
+                    pooled_projections=pooled_projections,
+                    encoder_hidden_states=encoder_hidden_states,
+                    txt_ids=txt_ids,
+                    img_ids=img_ids,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                    return_dict=return_dict,
+                )
+
+                # merge samples
+                if i == 0:
+                    control_block_samples = block_samples
+                    control_single_block_samples = single_block_samples
+                else:
+                    if block_samples is not None and control_block_samples is not None:
+                        control_block_samples = [
+                            control_block_sample + block_sample
+                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                        ]
+                    if single_block_samples is not None and control_single_block_samples is not None:
+                        control_single_block_samples = [
+                            control_single_block_sample + block_sample
+                            for control_single_block_sample, block_sample in zip(
+                                control_single_block_samples, single_block_samples
+                            )
+                        ]
+
+        # Regular Multi-ControlNets
+        # load all ControlNets into memories
+        else:
+            for i, (image, mode, scale, controlnet) in enumerate(
+                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
+            ):
+                block_samples, single_block_samples = controlnet(
+                    hidden_states=hidden_states,
+                    controlnet_cond=image,
+                    controlnet_mode=mode[:, None],
+                    conditioning_scale=scale,
+                    timestep=timestep,
+                    guidance=guidance,
+                    pooled_projections=pooled_projections,
+                    encoder_hidden_states=encoder_hidden_states,
+                    txt_ids=txt_ids,
+                    img_ids=img_ids,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                    return_dict=return_dict,
+                )
+
+                # merge samples
+                if i == 0:
+                    control_block_samples = block_samples
+                    control_single_block_samples = single_block_samples
+                else:
+                    if block_samples is not None and control_block_samples is not None:
+                        control_block_samples = [
+                            control_block_sample + block_sample
+                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
+                        ]
+                    if single_block_samples is not None and control_single_block_samples is not None:
+                        control_single_block_samples = [
+                            control_single_block_sample + block_sample
+                            for control_single_block_sample, block_sample in zip(
+                                control_single_block_samples, single_block_samples
+                            )
+                        ]
+
+        return control_block_samples, control_single_block_samples
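
Usage note (not part of this commit): the sketch below shows one way the FluxControlNetModel defined in controlnet_flux.py might be loaded and attached to a Flux pipeline with diffusers. The repository id is a placeholder and bfloat16 as the compute dtype is an assumption; adjust both to the actual checkpoint hosting the FP8-quantized weights.

import torch
from diffusers import FluxControlNetPipeline

from controlnet_flux import FluxControlNetModel

# Hypothetical repo id -- replace with the repository that actually hosts the quantized ControlNet weights.
controlnet = FluxControlNetModel.from_pretrained(
    "user/flux-controlnet-fp8",      # placeholder, not a real repository
    torch_dtype=torch.bfloat16,      # assumed compute dtype after dequantization
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()  # optional, trades speed for lower VRAM use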