TalHach61 committed on
Commit 7b638a8 · verified · 1 Parent(s): 0eace44

Upload 5 files

bria_utils.py ADDED
@@ -0,0 +1,72 @@
1
+ from typing import Union, Optional, List
2
+ import torch
3
+ from diffusers.utils import logging
4
+ from transformers import (
5
+ T5EncoderModel,
6
+ T5TokenizerFast,
7
+ )
8
+ import numpy as np
9
+
10
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
11
+
12
+ def get_t5_prompt_embeds(
13
+ tokenizer: T5TokenizerFast,
14
+ text_encoder: T5EncoderModel,
15
+ prompt: Union[str, List[str]] = None,
16
+ num_images_per_prompt: int = 1,
17
+ max_sequence_length: int = 128,
18
+ device: Optional[torch.device] = None,
19
+ ):
20
+ device = device or text_encoder.device
21
+
22
+ prompt = [prompt] if isinstance(prompt, str) else prompt
23
+ batch_size = len(prompt)
24
+
25
+ text_inputs = tokenizer(
26
+ prompt,
27
+ # padding="max_length",
28
+ max_length=max_sequence_length,
29
+ truncation=True,
30
+ add_special_tokens=True,
31
+ return_tensors="pt",
32
+ )
33
+ text_input_ids = text_inputs.input_ids
34
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
35
+
36
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
37
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
38
+ logger.warning(
39
+ "The following part of your input was truncated because `max_sequence_length` is set to "
40
+ f" {max_sequence_length} tokens: {removed_text}"
41
+ )
42
+
43
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
44
+
45
+ # Pad the sequence dimension with zeros up to max_sequence_length
46
+ b, seq_len, dim = prompt_embeds.shape
47
+ if seq_len < max_sequence_length:
48
+ padding = torch.zeros((b, max_sequence_length - seq_len, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
49
+ prompt_embeds = torch.concat([prompt_embeds, padding], dim=1)
50
+
51
+ prompt_embeds = prompt_embeds.to(device=device)
52
+
53
+ _, seq_len, _ = prompt_embeds.shape
54
+
55
+ # duplicate text embeddings for each generation per prompt, using an mps-friendly method
56
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
57
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
58
+
59
+ return prompt_embeds
60
+
61
+ # in order to get the same sigmas as in training and sample from them
62
+ def get_original_sigmas(num_train_timesteps=1000, num_inference_steps=1000):
63
+ timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
64
+ sigmas = timesteps / num_train_timesteps
65
+
66
+ inds = [int(ind) for ind in np.linspace(0, num_train_timesteps-1, num_inference_steps)]
67
+ new_sigmas = sigmas[inds]
68
+ return new_sigmas
69
+
70
+ def is_ng_none(negative_prompt):
71
+ return negative_prompt is None or negative_prompt == '' or (isinstance(negative_prompt, list) and negative_prompt[0] is None) or (isinstance(negative_prompt, list) and negative_prompt[0] == '')
72
+
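For context, here is a minimal usage sketch of the helpers above. It is illustrative only: the `google/t5-v1_1-xxl` checkpoint id and the printed embedding width are assumptions based on the T5-XXL encoder referenced in the pipeline docstrings, not something prescribed by this file.

```py
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from bria_utils import get_original_sigmas, get_t5_prompt_embeds, is_ng_none

# Illustrative checkpoint id; any T5 encoder/tokenizer pair with the expected hidden size works.
tokenizer = T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl")
text_encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl", torch_dtype=torch.float16).to("cuda")

# Encode a prompt; the result is zero-padded along the sequence dimension to max_sequence_length.
prompt_embeds = get_t5_prompt_embeds(
    tokenizer,
    text_encoder,
    prompt="A photo of a red bicycle leaning against a brick wall",
    num_images_per_prompt=1,
    max_sequence_length=128,
)
print(prompt_embeds.shape)  # (1, 128, hidden_dim), e.g. 4096 for T5-XXL

# Sigmas aligned with the training schedule, subsampled to 30 inference steps.
sigmas = get_original_sigmas(num_train_timesteps=1000, num_inference_steps=30)

# Helper used to decide whether a negative prompt was actually provided.
assert is_ng_none(None) and is_ng_none("") and not is_ng_none("blurry, low quality")
```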
controlnet_bria.py ADDED
@@ -0,0 +1,649 @@
1
+ # type: ignore
2
+ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from transformer_bria import TimestepProjEmbeddings
23
+ from diffusers.models.controlnet import zero_module, BaseOutput
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.loaders import PeftAdapterMixin
26
+ from diffusers.models.modeling_utils import ModelMixin
27
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
29
+
30
+ # from transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock, EmbedND
31
+ from diffusers.models.transformers.transformer_flux import EmbedND, FluxSingleTransformerBlock, FluxTransformerBlock
32
+
33
+ from diffusers.models.attention_processor import AttentionProcessor
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ @dataclass
39
+ class BriaControlNetOutput(BaseOutput):
40
+ controlnet_block_samples: Tuple[torch.Tensor]
41
+ controlnet_single_block_samples: Tuple[torch.Tensor]
42
+
43
+
44
+ class BriaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
45
+ _supports_gradient_checkpointing = True
46
+
47
+ @register_to_config
48
+ def __init__(
49
+ self,
50
+ patch_size: int = 1,
51
+ in_channels: int = 64,
52
+ num_layers: int = 19,
53
+ num_single_layers: int = 38,
54
+ attention_head_dim: int = 128,
55
+ num_attention_heads: int = 24,
56
+ joint_attention_dim: int = 4096,
57
+ pooled_projection_dim: int = 768,
58
+ guidance_embeds: bool = False,
59
+ axes_dims_rope: List[int] = [16, 56, 56],
60
+ num_mode: Optional[int] = None,
61
+ rope_theta: int = 10000,
62
+ time_theta: int = 10000,
63
+ ):
64
+ super().__init__()
65
+ self.out_channels = in_channels
66
+ self.inner_dim = num_attention_heads * attention_head_dim
67
+
68
+ # self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
69
+ self.pos_embed = EmbedND(dim=self.inner_dim, theta=rope_theta, axes_dim=axes_dims_rope)
70
+
71
+ # text_time_guidance_cls = (
72
+ # CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
73
+ # )
74
+ # self.time_text_embed = text_time_guidance_cls(
75
+ # embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
76
+ # )
77
+ self.time_embed = TimestepProjEmbeddings(
78
+ embedding_dim=self.inner_dim, max_period=10000
79
+ )
80
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
81
+ self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
82
+
83
+ self.transformer_blocks = nn.ModuleList(
84
+ [
85
+ FluxTransformerBlock(
86
+ dim=self.inner_dim,
87
+ num_attention_heads=num_attention_heads,
88
+ attention_head_dim=attention_head_dim,
89
+ )
90
+ for i in range(num_layers)
91
+ ]
92
+ )
93
+
94
+ self.single_transformer_blocks = nn.ModuleList(
95
+ [
96
+ FluxSingleTransformerBlock(
97
+ dim=self.inner_dim,
98
+ num_attention_heads=num_attention_heads,
99
+ attention_head_dim=attention_head_dim,
100
+ )
101
+ for i in range(num_single_layers)
102
+ ]
103
+ )
104
+
105
+ # controlnet_blocks
106
+ self.controlnet_blocks = nn.ModuleList([])
107
+ for _ in range(len(self.transformer_blocks)):
108
+ self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
109
+
110
+ self.controlnet_single_blocks = nn.ModuleList([])
111
+ for _ in range(len(self.single_transformer_blocks)):
112
+ self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
113
+
114
+ self.union = num_mode is not None and num_mode > 0
115
+ if self.union:
116
+ self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
117
+
118
+ self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
119
+
120
+ self.gradient_checkpointing = False
121
+
122
+ @property
123
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
124
+ def attn_processors(self):
125
+ r"""
126
+ Returns:
127
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
128
+ indexed by its weight name.
129
+ """
130
+ # set recursively
131
+ processors = {}
132
+
133
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
134
+ if hasattr(module, "get_processor"):
135
+ processors[f"{name}.processor"] = module.get_processor()
136
+
137
+ for sub_name, child in module.named_children():
138
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
139
+
140
+ return processors
141
+
142
+ for name, module in self.named_children():
143
+ fn_recursive_add_processors(name, module, processors)
144
+
145
+ return processors
146
+
147
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
148
+ def set_attn_processor(self, processor):
149
+ r"""
150
+ Sets the attention processor to use to compute attention.
151
+
152
+ Parameters:
153
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
154
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
155
+ for **all** `Attention` layers.
156
+
157
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
158
+ processor. This is strongly recommended when setting trainable attention processors.
159
+
160
+ """
161
+ count = len(self.attn_processors.keys())
162
+
163
+ if isinstance(processor, dict) and len(processor) != count:
164
+ raise ValueError(
165
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
166
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
167
+ )
168
+
169
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
170
+ if hasattr(module, "set_processor"):
171
+ if not isinstance(processor, dict):
172
+ module.set_processor(processor)
173
+ else:
174
+ module.set_processor(processor.pop(f"{name}.processor"))
175
+
176
+ for sub_name, child in module.named_children():
177
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
178
+
179
+ for name, module in self.named_children():
180
+ fn_recursive_attn_processor(name, module, processor)
181
+
182
+ def _set_gradient_checkpointing(self, module, value=False):
183
+ if hasattr(module, "gradient_checkpointing"):
184
+ module.gradient_checkpointing = value
185
+
186
+ @classmethod
187
+ def from_transformer(
188
+ cls,
189
+ transformer,
190
+ num_layers: int = 4,
191
+ num_single_layers: int = 10,
192
+ attention_head_dim: int = 128,
193
+ num_attention_heads: int = 24,
194
+ load_weights_from_transformer=True,
195
+ ):
196
+ config = transformer.config
197
+ config["num_layers"] = num_layers
198
+ config["num_single_layers"] = num_single_layers
199
+ config["attention_head_dim"] = attention_head_dim
200
+ config["num_attention_heads"] = num_attention_heads
201
+
202
+ controlnet = cls(**config)
203
+
204
+ if load_weights_from_transformer:
205
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
206
+ controlnet.time_embed.load_state_dict(transformer.time_embed.state_dict())
207
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
208
+ controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
209
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
210
+ controlnet.single_transformer_blocks.load_state_dict(
211
+ transformer.single_transformer_blocks.state_dict(), strict=False
212
+ )
213
+
214
+ controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
215
+
216
+ return controlnet
217
+
218
+ def forward(
219
+ self,
220
+ hidden_states: torch.Tensor,
221
+ controlnet_cond: torch.Tensor,
222
+ controlnet_mode: torch.Tensor = None,
223
+ conditioning_scale: float = 1.0,
224
+ encoder_hidden_states: torch.Tensor = None,
225
+ pooled_projections: torch.Tensor = None,
226
+ timestep: torch.LongTensor = None,
227
+ img_ids: torch.Tensor = None,
228
+ txt_ids: torch.Tensor = None,
229
+ guidance: torch.Tensor = None,
230
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
231
+ return_dict: bool = True,
232
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
233
+ """
234
+ The [`FluxTransformer2DModel`] forward method.
235
+
236
+ Args:
237
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
238
+ Input `hidden_states`.
239
+ controlnet_cond (`torch.Tensor`):
240
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
241
+ controlnet_mode (`torch.Tensor`):
242
+ The mode tensor of shape `(batch_size, 1)`.
243
+ conditioning_scale (`float`, defaults to `1.0`):
244
+ The scale factor for ControlNet outputs.
245
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
246
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
247
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
248
+ from the embeddings of input conditions.
249
+ timestep ( `torch.LongTensor`):
250
+ Used to indicate denoising step.
251
+ block_controlnet_hidden_states: (`list` of `torch.Tensor`):
252
+ A list of tensors that if specified are added to the residuals of transformer blocks.
253
+ joint_attention_kwargs (`dict`, *optional*):
254
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
255
+ `self.processor` in
256
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
257
+ return_dict (`bool`, *optional*, defaults to `True`):
258
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
259
+ tuple.
260
+
261
+ Returns:
262
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
263
+ `tuple` where the first element is the sample tensor.
264
+ """
265
+ if guidance is not None:
266
+ logger.warning("guidance is not supported in BriaControlNetModel")
267
+ if pooled_projections is not None:
268
+ logger.warning("pooled_projections is not supported in BriaControlNetModel")
269
+ if joint_attention_kwargs is not None:
270
+ joint_attention_kwargs = joint_attention_kwargs.copy()
271
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
272
+ else:
273
+ lora_scale = 1.0
274
+
275
+ if USE_PEFT_BACKEND:
276
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
277
+ scale_lora_layers(self, lora_scale)
278
+ else:
279
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
280
+ logger.warning(
281
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
282
+ )
283
+ hidden_states = self.x_embedder(hidden_states)
284
+
285
+ # add
286
+ hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
287
+
288
+ timestep = timestep.to(hidden_states.dtype) # Original code was * 1000
289
+ if guidance is not None:
290
+ guidance = guidance.to(hidden_states.dtype) # Original code was * 1000
291
+ else:
292
+ guidance = None
293
+ # temb = (
294
+ # self.time_text_embed(timestep, pooled_projections)
295
+ # if guidance is None
296
+ # else self.time_text_embed(timestep, guidance, pooled_projections)
297
+ # )
298
+ temb = self.time_embed(timestep, dtype=hidden_states.dtype)
299
+
300
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
301
+
302
+ if self.union:
303
+ # union mode
304
+ if controlnet_mode is None:
305
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
306
+ # union mode emb
307
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
308
+ if controlnet_mode_emb.shape[0] < encoder_hidden_states.shape[0]:
309
+ controlnet_mode_emb = controlnet_mode_emb.expand(encoder_hidden_states.shape[0], 1, 2048)
310
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
311
+ txt_ids = torch.cat((txt_ids[:, 0:1, :], txt_ids), dim=1)
312
+
313
+ # if txt_ids.ndim == 3:
314
+ # logger.warning(
315
+ # "Passing `txt_ids` 3d torch.Tensor is deprecated."
316
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
317
+ # )
318
+ # txt_ids = txt_ids[0]
319
+ # if img_ids.ndim == 3:
320
+ # logger.warning(
321
+ # "Passing `img_ids` 3d torch.Tensor is deprecated."
322
+ # "Please remove the batch dimension and pass it as a 2d torch Tensor"
323
+ # )
324
+ # img_ids = img_ids[0]
325
+
326
+ # ids = torch.cat((txt_ids, img_ids), dim=0)
327
+ ids = torch.cat((txt_ids, img_ids), dim=1)
328
+ image_rotary_emb = self.pos_embed(ids)
329
+
330
+ block_samples = ()
331
+ for index_block, block in enumerate(self.transformer_blocks):
332
+ if self.training and self.gradient_checkpointing:
333
+
334
+ def create_custom_forward(module, return_dict=None):
335
+ def custom_forward(*inputs):
336
+ if return_dict is not None:
337
+ return module(*inputs, return_dict=return_dict)
338
+ else:
339
+ return module(*inputs)
340
+
341
+ return custom_forward
342
+
343
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
345
+ create_custom_forward(block),
346
+ hidden_states,
347
+ encoder_hidden_states,
348
+ temb,
349
+ image_rotary_emb,
350
+ **ckpt_kwargs,
351
+ )
352
+
353
+ else:
354
+ encoder_hidden_states, hidden_states = block(
355
+ hidden_states=hidden_states,
356
+ encoder_hidden_states=encoder_hidden_states,
357
+ temb=temb,
358
+ image_rotary_emb=image_rotary_emb,
359
+ )
360
+ block_samples = block_samples + (hidden_states,)
361
+
362
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
363
+
364
+ single_block_samples = ()
365
+ for index_block, block in enumerate(self.single_transformer_blocks):
366
+ if self.training and self.gradient_checkpointing:
367
+
368
+ def create_custom_forward(module, return_dict=None):
369
+ def custom_forward(*inputs):
370
+ if return_dict is not None:
371
+ return module(*inputs, return_dict=return_dict)
372
+ else:
373
+ return module(*inputs)
374
+
375
+ return custom_forward
376
+
377
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
378
+ hidden_states = torch.utils.checkpoint.checkpoint(
379
+ create_custom_forward(block),
380
+ hidden_states,
381
+ temb,
382
+ image_rotary_emb,
383
+ **ckpt_kwargs,
384
+ )
385
+
386
+ else:
387
+ hidden_states = block(
388
+ hidden_states=hidden_states,
389
+ temb=temb,
390
+ image_rotary_emb=image_rotary_emb,
391
+ )
392
+ single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
393
+
394
+ # controlnet block
395
+ controlnet_block_samples = ()
396
+ for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
397
+ block_sample = controlnet_block(block_sample)
398
+ controlnet_block_samples = controlnet_block_samples + (block_sample,)
399
+
400
+ controlnet_single_block_samples = ()
401
+ for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
402
+ single_block_sample = controlnet_block(single_block_sample)
403
+ controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
404
+
405
+ # scaling
406
+ controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
407
+ controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
408
+
409
+ controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
410
+ controlnet_single_block_samples = (
411
+ None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
412
+ )
413
+
414
+ if USE_PEFT_BACKEND:
415
+ # remove `lora_scale` from each PEFT layer
416
+ unscale_lora_layers(self, lora_scale)
417
+
418
+ if not return_dict:
419
+ return (controlnet_block_samples, controlnet_single_block_samples)
420
+
421
+ return BriaControlNetOutput(
422
+ controlnet_block_samples=controlnet_block_samples,
423
+ controlnet_single_block_samples=controlnet_single_block_samples,
424
+ )
425
+
426
+
427
+ class BriaMultiControlNetModel(ModelMixin):
428
+ r"""
429
+ `BriaMultiControlNetModel` wrapper class for Multi-BriaControlNetModel
430
+
431
+ This module is a wrapper for multiple instances of the `BriaControlNetModel`. The `forward()` API is designed to be
432
+ compatible with `BriaControlNetModel`.
433
+
434
+ Args:
435
+ controlnets (`List[BriaControlNetModel]`):
436
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
437
+ `BriaControlNetModel` as a list.
438
+ """
439
+
440
+ def __init__(self, controlnets):
441
+ super().__init__()
442
+ self.nets = nn.ModuleList(controlnets)
443
+
444
+ def forward(
445
+ self,
446
+ hidden_states: torch.FloatTensor,
447
+ controlnet_cond: List[torch.tensor],
448
+ controlnet_mode: List[torch.tensor],
449
+ conditioning_scale: List[float],
450
+ encoder_hidden_states: torch.Tensor = None,
451
+ pooled_projections: torch.Tensor = None,
452
+ timestep: torch.LongTensor = None,
453
+ img_ids: torch.Tensor = None,
454
+ txt_ids: torch.Tensor = None,
455
+ guidance: torch.Tensor = None,
456
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
457
+ return_dict: bool = True,
458
+ ) -> Union[BriaControlNetOutput, Tuple]:
459
+ # ControlNet-Union with multiple conditions
460
+ # only load one ControlNet to save memory
461
+ if len(self.nets) == 1 and self.nets[0].union:
462
+ controlnet = self.nets[0]
463
+
464
+ for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
465
+ block_samples, single_block_samples = controlnet(
466
+ hidden_states=hidden_states,
467
+ controlnet_cond=image,
468
+ controlnet_mode=mode[:, None],
469
+ conditioning_scale=scale,
470
+ timestep=timestep,
471
+ guidance=guidance,
472
+ pooled_projections=pooled_projections,
473
+ encoder_hidden_states=encoder_hidden_states,
474
+ txt_ids=txt_ids,
475
+ img_ids=img_ids,
476
+ joint_attention_kwargs=joint_attention_kwargs,
477
+ return_dict=return_dict,
478
+ )
479
+
480
+ # merge samples
481
+ if i == 0:
482
+ control_block_samples = block_samples
483
+ control_single_block_samples = single_block_samples
484
+ else:
485
+ control_block_samples = [
486
+ control_block_sample + block_sample
487
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
488
+ ]
489
+
490
+ control_single_block_samples = [
491
+ control_single_block_sample + block_sample
492
+ for control_single_block_sample, block_sample in zip(
493
+ control_single_block_samples, single_block_samples
494
+ )
495
+ ]
496
+
497
+ # Regular Multi-ControlNets
498
+ # load all ControlNets into memory
499
+ else:
500
+ for i, (image, mode, scale, controlnet) in enumerate(
501
+ zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
502
+ ):
503
+ block_samples, single_block_samples = controlnet(
504
+ hidden_states=hidden_states,
505
+ controlnet_cond=image,
506
+ controlnet_mode=mode[:, None],
507
+ conditioning_scale=scale,
508
+ timestep=timestep,
509
+ guidance=guidance,
510
+ pooled_projections=pooled_projections,
511
+ encoder_hidden_states=encoder_hidden_states,
512
+ txt_ids=txt_ids,
513
+ img_ids=img_ids,
514
+ joint_attention_kwargs=joint_attention_kwargs,
515
+ return_dict=return_dict,
516
+ )
517
+
518
+ # merge samples
519
+ if i == 0:
520
+ control_block_samples = block_samples
521
+ control_single_block_samples = single_block_samples
522
+ else:
523
+ if block_samples is not None and control_block_samples is not None:
524
+ control_block_samples = [
525
+ control_block_sample + block_sample
526
+ for control_block_sample, block_sample in zip(control_block_samples, block_samples)
527
+ ]
528
+ if single_block_samples is not None and control_single_block_samples is not None:
529
+ control_single_block_samples = [
530
+ control_single_block_sample + block_sample
531
+ for control_single_block_sample, block_sample in zip(
532
+ control_single_block_samples, single_block_samples
533
+ )
534
+ ]
535
+
536
+ return control_block_samples, control_single_block_samples
537
+
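As a sanity check on the interfaces defined above, the sketch below instantiates a small `BriaControlNetModel` and runs one forward pass on dummy tensors. The shapes are assumptions derived from the constructor defaults (`in_channels=64`, `joint_attention_dim=4096`) and the packed-latent layout used by the Flux-style blocks; a real workflow would load pretrained weights instead of random ones.

```py
import torch
from controlnet_bria import BriaControlNetModel

# Tiny, randomly initialized model: one double block and one single block, default dims otherwise.
controlnet = BriaControlNetModel(num_layers=1, num_single_layers=1)

b, img_tokens, txt_tokens = 1, 16, 8
hidden_states = torch.randn(b, img_tokens, 64)            # packed image latents (in_channels=64)
controlnet_cond = torch.randn(b, img_tokens, 64)          # control image in the same packed space
encoder_hidden_states = torch.randn(b, txt_tokens, 4096)  # T5 embeddings (joint_attention_dim=4096)
img_ids = torch.zeros(b, img_tokens, 3)                   # positional ids consumed by EmbedND
txt_ids = torch.zeros(b, txt_tokens, 3)
timestep = torch.tensor([0.5])                            # flow-matching timestep in [0, 1]

out = controlnet(
    hidden_states=hidden_states,
    controlnet_cond=controlnet_cond,
    conditioning_scale=1.0,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    img_ids=img_ids,
    txt_ids=txt_ids,
)
# One residual per double block and per single block, each scaled by conditioning_scale.
print(len(out.controlnet_block_samples), len(out.controlnet_single_block_samples))  # 1 1
```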
pipeline_bria.py ADDED
@@ -0,0 +1,558 @@
1
+ from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import torch
5
+
6
+ from transformers import (
7
+ T5EncoderModel,
8
+ T5TokenizerFast,
9
+ )
10
+
11
+ from diffusers.image_processor import VaeImageProcessor
12
+ from diffusers import AutoencoderKL, DDIMScheduler, EulerAncestralDiscreteScheduler
13
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
14
+ from diffusers.schedulers import KarrasDiffusionSchedulers
15
+ from diffusers.loaders import FluxLoraLoaderMixin
16
+ from diffusers.utils import (
17
+ USE_PEFT_BACKEND,
18
+ is_torch_xla_available,
19
+ logging,
20
+ replace_example_docstring,
21
+ scale_lora_layers,
22
+ unscale_lora_layers,
23
+ )
24
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
26
+ from transformer_bria import BriaTransformer2DModel
27
+ from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none
28
+
29
+ if is_torch_xla_available():
30
+ import torch_xla.core.xla_model as xm
31
+
32
+ XLA_AVAILABLE = True
33
+ else:
34
+ XLA_AVAILABLE = False
35
+
36
+
37
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
+
39
+ EXAMPLE_DOC_STRING = """
40
+ Examples:
41
+ ```py
42
+ >>> import torch
43
+ >>> from diffusers import StableDiffusion3Pipeline
44
+
45
+ >>> pipe = StableDiffusion3Pipeline.from_pretrained(
46
+ ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
47
+ ... )
48
+ >>> pipe.to("cuda")
49
+ >>> prompt = "A cat holding a sign that says hello world"
50
+ >>> image = pipe(prompt).images[0]
51
+ >>> image.save("sd3.png")
52
+ ```
53
+ """
54
+
55
+ T5_PRECISION = torch.float16
56
+
57
+ """
58
+ Based on FluxPipeline with several changes:
59
+ - no pooled embeddings
60
+ - We use zero padding for prompts
61
+ - No guidance embedding since this is not a distilled version
62
+ """
63
+ class BriaPipeline(FluxPipeline):
64
+ r"""
65
+ Args:
66
+ transformer ([`SD3Transformer2DModel`]):
67
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
68
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
69
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
70
+ vae ([`AutoencoderKL`]):
71
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
72
+ text_encoder ([`T5EncoderModel`]):
73
+ Frozen text-encoder. Stable Diffusion 3 uses
74
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
75
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
76
+ tokenizer (`T5TokenizerFast`):
77
+ Tokenizer of class
78
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
79
+ """
80
+
81
+ # model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
82
+ # _optional_components = []
83
+ # _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
84
+
85
+ def __init__(
86
+ self,
87
+ transformer: BriaTransformer2DModel,
88
+ scheduler: Union[FlowMatchEulerDiscreteScheduler,KarrasDiffusionSchedulers],
89
+ vae: AutoencoderKL,
90
+ text_encoder: T5EncoderModel,
91
+ tokenizer: T5TokenizerFast
92
+ ):
93
+ self.register_modules(
94
+ vae=vae,
95
+ text_encoder=text_encoder,
96
+ tokenizer=tokenizer,
97
+ transformer=transformer,
98
+ scheduler=scheduler,
99
+ )
100
+
101
+ # TODO - why is this different from the official Flux pipeline (-1)?
102
+ self.vae_scale_factor = (
103
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
104
+ )
105
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
106
+ self.default_sample_size = 64 # due to patchify => 128x128 latent => 1024x1024 output resolution
107
+
108
+ # T5 is sensitive to precision, so we use the precision used for precompute and cast as needed
109
+ self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
110
+ for block in self.text_encoder.encoder.block:
111
+ block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
112
+
113
+ def encode_prompt(
114
+ self,
115
+ prompt: Union[str, List[str]],
116
+ device: Optional[torch.device] = None,
117
+ num_images_per_prompt: int = 1,
118
+ do_classifier_free_guidance: bool = True,
119
+ negative_prompt: Optional[Union[str, List[str]]] = None,
120
+ prompt_embeds: Optional[torch.FloatTensor] = None,
121
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
122
+ max_sequence_length: int = 128,
123
+ lora_scale: Optional[float] = None,
124
+ ):
125
+ r"""
126
+
127
+ Args:
128
+ prompt (`str` or `List[str]`, *optional*):
129
+ prompt to be encoded
130
+ device: (`torch.device`):
131
+ torch device
132
+ num_images_per_prompt (`int`):
133
+ number of images that should be generated per prompt
134
+ do_classifier_free_guidance (`bool`):
135
+ whether to use classifier free guidance or not
136
+ negative_prompt (`str` or `List[str]`, *optional*):
137
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
138
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
139
+ less than `1`).
140
+ prompt_embeds (`torch.FloatTensor`, *optional*):
141
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
142
+ provided, text embeddings will be generated from `prompt` input argument.
143
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
144
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
145
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
146
+ argument.
147
+ """
148
+ device = device or self._execution_device
149
+
150
+ # set lora scale so that monkey patched LoRA
151
+ # function of text encoder can correctly access it
152
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
153
+ self._lora_scale = lora_scale
154
+
155
+ # dynamically adjust the LoRA scale
156
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
157
+ scale_lora_layers(self.text_encoder, lora_scale)
158
+
159
+ prompt = [prompt] if isinstance(prompt, str) else prompt
160
+ if prompt is not None:
161
+ batch_size = len(prompt)
162
+ else:
163
+ batch_size = prompt_embeds.shape[0]
164
+
165
+ if prompt_embeds is None:
166
+ prompt_embeds = get_t5_prompt_embeds(
167
+ self.tokenizer,
168
+ self.text_encoder,
169
+ prompt=prompt,
170
+ num_images_per_prompt=num_images_per_prompt,
171
+ max_sequence_length=max_sequence_length,
172
+ device=device,
173
+ ).to(dtype=self.transformer.dtype)
174
+
175
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
176
+ if not is_ng_none(negative_prompt):
177
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
178
+
179
+ if prompt is not None and type(prompt) is not type(negative_prompt):
180
+ raise TypeError(
181
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
182
+ f" {type(prompt)}."
183
+ )
184
+ elif batch_size != len(negative_prompt):
185
+ raise ValueError(
186
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
187
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
188
+ " the batch size of `prompt`."
189
+ )
190
+
191
+ negative_prompt_embeds = get_t5_prompt_embeds(
192
+ self.tokenizer,
193
+ self.text_encoder,
194
+ prompt=negative_prompt,
195
+ num_images_per_prompt=num_images_per_prompt,
196
+ max_sequence_length=max_sequence_length,
197
+ device=device,
198
+ ).to(dtype=self.transformer.dtype)
199
+ else:
200
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
201
+
202
+ if self.text_encoder is not None:
203
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
204
+ # Retrieve the original scale by scaling back the LoRA layers
205
+ unscale_lora_layers(self.text_encoder, lora_scale)
206
+
207
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
208
+ text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
209
+ text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
210
+
211
+ return prompt_embeds, negative_prompt_embeds, text_ids
212
+
213
+ @property
214
+ def guidance_scale(self):
215
+ return self._guidance_scale
216
+
217
+
218
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
219
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
220
+ # corresponds to doing no classifier free guidance.
221
+ @property
222
+ def do_classifier_free_guidance(self):
223
+ return self._guidance_scale > 1
224
+
225
+ @property
226
+ def joint_attention_kwargs(self):
227
+ return self._joint_attention_kwargs
228
+
229
+ @property
230
+ def num_timesteps(self):
231
+ return self._num_timesteps
232
+
233
+ @property
234
+ def interrupt(self):
235
+ return self._interrupt
236
+
237
+ @torch.no_grad()
238
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
239
+ def __call__(
240
+ self,
241
+ prompt: Union[str, List[str]] = None,
242
+ height: Optional[int] = None,
243
+ width: Optional[int] = None,
244
+ num_inference_steps: int = 30,
245
+ timesteps: List[int] = None,
246
+ guidance_scale: float = 5,
247
+ negative_prompt: Optional[Union[str, List[str]]] = None,
248
+ num_images_per_prompt: Optional[int] = 1,
249
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
250
+ latents: Optional[torch.FloatTensor] = None,
251
+ prompt_embeds: Optional[torch.FloatTensor] = None,
252
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
253
+ output_type: Optional[str] = "pil",
254
+ return_dict: bool = True,
255
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
256
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
257
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
258
+ max_sequence_length: int = 128,
259
+ clip_value:Union[None,float] = None,
260
+ normalize:bool = False
261
+ ):
262
+ r"""
263
+ Function invoked when calling the pipeline for generation.
264
+
265
+ Args:
266
+ prompt (`str` or `List[str]`, *optional*):
267
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
268
+ instead.
269
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
270
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
271
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
272
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
273
+ num_inference_steps (`int`, *optional*, defaults to 30):
274
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
275
+ expense of slower inference.
276
+ timesteps (`List[int]`, *optional*):
277
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
278
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
279
+ passed will be used. Must be in descending order.
280
+ guidance_scale (`float`, *optional*, defaults to 5.0):
281
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
282
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
283
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
284
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
285
+ usually at the expense of lower image quality.
286
+ negative_prompt (`str` or `List[str]`, *optional*):
287
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
288
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
289
+ less than `1`).
290
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
291
+ The number of images to generate per prompt.
292
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
293
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
294
+ to make generation deterministic.
295
+ latents (`torch.FloatTensor`, *optional*):
296
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
297
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
298
+ tensor will be generated by sampling using the supplied random `generator`.
299
+ prompt_embeds (`torch.FloatTensor`, *optional*):
300
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
301
+ provided, text embeddings will be generated from `prompt` input argument.
302
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
303
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
304
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
305
+ argument.
306
+ output_type (`str`, *optional*, defaults to `"pil"`):
307
+ The output format of the generated image. Choose between
308
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
309
+ return_dict (`bool`, *optional*, defaults to `True`):
310
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
311
+ of a plain tuple.
312
+ joint_attention_kwargs (`dict`, *optional*):
313
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
314
+ `self.processor` in
315
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
316
+ callback_on_step_end (`Callable`, *optional*):
317
+ A function that calls at the end of each denoising steps during the inference. The function is called
318
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
319
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
320
+ `callback_on_step_end_tensor_inputs`.
321
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
322
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
323
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
324
+ `._callback_tensor_inputs` attribute of your pipeline class.
325
+ max_sequence_length (`int`, *optional*, defaults to 128): Maximum sequence length to use with the `prompt`.
326
+
327
+ Examples:
328
+
329
+ Returns:
330
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
331
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
332
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
333
+ """
334
+
335
+ height = height or self.default_sample_size * self.vae_scale_factor
336
+ width = width or self.default_sample_size * self.vae_scale_factor
337
+
338
+ # 1. Check inputs. Raise error if not correct
339
+ self.check_inputs(
340
+ prompt=prompt,
341
+ height=height,
342
+ width=width,
343
+ prompt_embeds=prompt_embeds,
344
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
345
+ max_sequence_length=max_sequence_length,
346
+ )
347
+
348
+ self._guidance_scale = guidance_scale
349
+ self._joint_attention_kwargs = joint_attention_kwargs
350
+ self._interrupt = False
351
+
352
+ # 2. Define call parameters
353
+ if prompt is not None and isinstance(prompt, str):
354
+ batch_size = 1
355
+ elif prompt is not None and isinstance(prompt, list):
356
+ batch_size = len(prompt)
357
+ else:
358
+ batch_size = prompt_embeds.shape[0]
359
+
360
+ device = self._execution_device
361
+
362
+ lora_scale = (
363
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
364
+ )
365
+
366
+ (
367
+ prompt_embeds,
368
+ negative_prompt_embeds,
369
+ text_ids
370
+ ) = self.encode_prompt(
371
+ prompt=prompt,
372
+ negative_prompt=negative_prompt,
373
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
374
+ prompt_embeds=prompt_embeds,
375
+ negative_prompt_embeds=negative_prompt_embeds,
376
+ device=device,
377
+ num_images_per_prompt=num_images_per_prompt,
378
+ max_sequence_length=max_sequence_length,
379
+ lora_scale=lora_scale,
380
+ )
381
+
382
+ if self.do_classifier_free_guidance:
383
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
384
+
385
+ # 4. Prepare timesteps
386
+ # Sample from training sigmas
387
+ if isinstance(self.scheduler,DDIMScheduler) or isinstance(self.scheduler,EulerAncestralDiscreteScheduler):
388
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
389
+ else:
390
+ sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps,num_inference_steps=num_inference_steps)
391
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps,sigmas=sigmas)
392
+
393
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
394
+ self._num_timesteps = len(timesteps)
395
+
396
+ # 5. Prepare latent variables
397
+ num_channels_latents = self.transformer.config.in_channels // 4 # due to patch=2, we divide by 4
398
+ latents, latent_image_ids = self.prepare_latents(
399
+ batch_size * num_images_per_prompt,
400
+ num_channels_latents,
401
+ height,
402
+ width,
403
+ prompt_embeds.dtype,
404
+ device,
405
+ generator,
406
+ latents,
407
+ )
408
+
409
+ # Support different diffusers versions
410
+ if len(latent_image_ids.shape)==2:
411
+ text_ids=text_ids.squeeze()
412
+
413
+ # 6. Denoising loop
414
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
415
+ for i, t in enumerate(timesteps):
416
+ if self.interrupt:
417
+ continue
418
+
419
+ # expand the latents if we are doing classifier free guidance
420
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
421
+ if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
422
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
423
+
424
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
425
+ timestep = t.expand(latent_model_input.shape[0])
426
+
427
+ # This predicts "v" for flow matching or eps for diffusion
428
+ noise_pred = self.transformer(
429
+ hidden_states=latent_model_input,
430
+ timestep=timestep,
431
+ encoder_hidden_states=prompt_embeds,
432
+ joint_attention_kwargs=self.joint_attention_kwargs,
433
+ return_dict=False,
434
+ txt_ids=text_ids,
435
+ img_ids=latent_image_ids,
436
+ )[0]
437
+
438
+ # perform guidance
439
+ if self.do_classifier_free_guidance:
440
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
441
+ cfg_noise_pred_text = noise_pred_text.std()
442
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
443
+
444
+ if normalize:
445
+ noise_pred = noise_pred * (0.7 *(cfg_noise_pred_text/noise_pred.std())) + 0.3 * noise_pred
446
+
447
+ if clip_value:
448
+ assert clip_value>0
449
+ noise_pred = noise_pred.clip(-clip_value,clip_value)
450
+
451
+ # compute the previous noisy sample x_t -> x_t-1
452
+ latents_dtype = latents.dtype
453
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
454
+
455
+
456
+ # if latents.std().item()>2:
457
+ # print('Warning')
458
+
459
+ # print(t.item(),latents.mean().item(),latents.std().item())
460
+
461
+ if latents.dtype != latents_dtype:
462
+ if torch.backends.mps.is_available():
463
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
464
+ latents = latents.to(latents_dtype)
465
+
466
+ if callback_on_step_end is not None:
467
+ callback_kwargs = {}
468
+ for k in callback_on_step_end_tensor_inputs:
469
+ callback_kwargs[k] = locals()[k]
470
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
471
+
472
+ latents = callback_outputs.pop("latents", latents)
473
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
474
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
475
+
476
+ # call the callback, if provided
477
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
478
+ progress_bar.update()
479
+
480
+ if XLA_AVAILABLE:
481
+ xm.mark_step()
482
+
483
+ if output_type == "latent":
484
+ image = latents
485
+
486
+ else:
487
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
488
+ latents = (latents.to(dtype=torch.float32) / self.vae.config.scaling_factor) + self.vae.config.shift_factor
489
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
490
+ image = self.image_processor.postprocess(image, output_type=output_type)
491
+
492
+ # Offload all models
493
+ self.maybe_free_model_hooks()
494
+
495
+ if not return_dict:
496
+ return (image,)
497
+
498
+ return FluxPipelineOutput(images=image)
499
+
500
+ def check_inputs(
501
+ self,
502
+ prompt,
503
+ height,
504
+ width,
505
+ negative_prompt=None,
506
+ prompt_embeds=None,
507
+ negative_prompt_embeds=None,
508
+ callback_on_step_end_tensor_inputs=None,
509
+ max_sequence_length=None,
510
+ ):
511
+ if height % 8 != 0 or width % 8 != 0:
512
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
513
+
514
+ if callback_on_step_end_tensor_inputs is not None and not all(
515
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
516
+ ):
517
+ raise ValueError(
518
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
519
+ )
520
+
521
+ if prompt is not None and prompt_embeds is not None:
522
+ raise ValueError(
523
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
524
+ " only forward one of the two."
525
+ )
526
+ elif prompt is None and prompt_embeds is None:
527
+ raise ValueError(
528
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
529
+ )
530
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
531
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
532
+
533
+ if negative_prompt is not None and negative_prompt_embeds is not None:
534
+ raise ValueError(
535
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
536
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
537
+ )
538
+
539
+
540
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
541
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
542
+ raise ValueError(
543
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
544
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
545
+ f" {negative_prompt_embeds.shape}."
546
+ )
547
+
548
+ if max_sequence_length is not None and max_sequence_length > 512:
549
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
550
+
551
+ def to(self, *args, **kwargs):
552
+ DiffusionPipeline.to(self, *args, **kwargs)
553
+ # T5 is sensitive to precision, so we use the precision used for precompute and cast as needed
554
+ self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
555
+ for block in self.text_encoder.encoder.block:
556
+ block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
557
+
558
+ return self
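Assuming the five uploaded files sit next to each other and the surrounding repository provides matching weights, a hedged end-to-end sketch of `BriaPipeline` could look like the following. The checkpoint path is a placeholder, and `bfloat16` is only a suggested dtype (the T5 encoder is re-cast internally to `T5_PRECISION` regardless).

```py
import torch
from pipeline_bria import BriaPipeline

# Placeholder path/repo id; point it at the checkpoint these modules were uploaded for.
pipe = BriaPipeline.from_pretrained("<path-or-repo-id-of-this-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A photo of a red bicycle leaning against a brick wall",
    negative_prompt=None,   # is_ng_none(None) is True, so zero embeddings are used for the negative branch
    num_inference_steps=30,
    guidance_scale=5.0,     # CFG is active whenever guidance_scale > 1
    max_sequence_length=128,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("bria_sample.png")
```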
pipeline_bria_controlnet.py ADDED
@@ -0,0 +1,532 @@
1
+ # Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, Callable, Dict, List, Optional, Union
16
+ import torch
17
+ from transformers import (
18
+ T5EncoderModel,
19
+ T5TokenizerFast,
20
+ )
21
+ from diffusers.image_processor import PipelineImageInput
22
+
23
+ from diffusers import AutoencoderKL  # Waiting for diffusers update
24
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
25
+ from diffusers.schedulers import KarrasDiffusionSchedulers
26
+ from diffusers.utils import logging
27
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
28
+ from diffusers.pipelines.flux.pipeline_flux import retrieve_timesteps
29
+ from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
30
+
31
+ from pipeline_bria import BriaPipeline
32
+ from transformer_bria import BriaTransformer2DModel
33
+ from bria_utils import get_original_sigmas
34
+
35
+ XLA_AVAILABLE = False
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ class BriaControlNetPipeline(BriaPipeline):
42
+ r"""
43
+ Args:
44
+ transformer ([`BriaTransformer2DModel`]):
45
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
46
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
47
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
48
+ vae ([`AutoencoderKL`]):
49
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
50
+ text_encoder ([`T5EncoderModel`]):
51
+ Frozen text-encoder. Stable Diffusion 3 uses
52
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
53
+ [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
54
+ tokenizer (`T5TokenizerFast`):
55
+ Tokenizer of class
56
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
57
+ """
58
+
59
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
60
+ _optional_components = []
61
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
62
+
63
+ def __init__( # EYAL - removed clip text encoder + tokenizer
64
+ self,
65
+ transformer: BriaTransformer2DModel,
66
+ scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
67
+ vae: AutoencoderKL,
68
+ text_encoder: T5EncoderModel,
69
+ tokenizer: T5TokenizerFast,
70
+ controlnet: BriaControlNetModel,
71
+ ):
72
+ super().__init__(
73
+ transformer=transformer, scheduler=scheduler, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer
74
+ )
75
+ self.register_modules(controlnet=controlnet)
76
+
77
+ def prepare_image(
78
+ self,
79
+ image,
80
+ width,
81
+ height,
82
+ batch_size,
83
+ num_images_per_prompt,
84
+ device,
85
+ dtype,
86
+ do_classifier_free_guidance=False,
87
+ guess_mode=False,
88
+ ):
89
+ if isinstance(image, torch.Tensor):
90
+ pass
91
+ else:
92
+ image = self.image_processor.preprocess(image, height=height, width=width)
93
+
94
+ image_batch_size = image.shape[0]
95
+
96
+ if image_batch_size == 1:
97
+ repeat_by = batch_size
98
+ else:
99
+ # image batch size is the same as prompt batch size
100
+ repeat_by = num_images_per_prompt
101
+
102
+ image = image.repeat_interleave(repeat_by, dim=0)
103
+
104
+ image = image.to(device=device, dtype=dtype)
105
+
106
+ if do_classifier_free_guidance and not guess_mode:
107
+ image = torch.cat([image] * 2)
108
+
109
+ return image
110
+
111
+ def prepare_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
112
+ num_channels_latents = self.transformer.config.in_channels // 4
113
+ control_image = self.prepare_image(
114
+ image=control_image,
115
+ width=width,
116
+ height=height,
117
+ batch_size=batch_size * num_images_per_prompt,
118
+ num_images_per_prompt=num_images_per_prompt,
119
+ device=device,
120
+ dtype=self.vae.dtype,
121
+ )
122
+ height, width = control_image.shape[-2:]
123
+
124
+ # vae encode
125
+ control_image = self.vae.encode(control_image).latent_dist.sample()
126
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
127
+
128
+ # pack
129
+ height_control_image, width_control_image = control_image.shape[2:]
130
+ control_image = self._pack_latents(
131
+ control_image,
132
+ batch_size * num_images_per_prompt,
133
+ num_channels_latents,
134
+ height_control_image,
135
+ width_control_image,
136
+ )
137
+
138
+ # Here we ensure that `control_mode` has the same length as the control_image.
139
+ if control_mode is not None:
140
+ if not isinstance(control_mode, int):
141
+ raise ValueError("For `BriaControlNetModel`, `control_mode` should be an `int` or `None`")
142
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
143
+ control_mode = control_mode.view(-1, 1).expand(control_image.shape[0], 1)
144
+
145
+ return control_image, control_mode
146
+
147
+ def prepare_multi_control(self, control_image, width, height, batch_size, num_images_per_prompt, device, control_mode):
148
+ num_channels_latents = self.transformer.config.in_channels // 4
149
+ control_images = []
150
+ for i, control_image_ in enumerate(control_image):
151
+ control_image_ = self.prepare_image(
152
+ image=control_image_,
153
+ width=width,
154
+ height=height,
155
+ batch_size=batch_size * num_images_per_prompt,
156
+ num_images_per_prompt=num_images_per_prompt,
157
+ device=device,
158
+ dtype=self.vae.dtype,
159
+ )
160
+ height, width = control_image_.shape[-2:]
161
+
162
+ # vae encode
163
+ control_image_ = self.vae.encode(control_image_).latent_dist.sample()
164
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
165
+
166
+ # pack
167
+ height_control_image, width_control_image = control_image_.shape[2:]
168
+ control_image_ = self._pack_latents(
169
+ control_image_,
170
+ batch_size * num_images_per_prompt,
171
+ num_channels_latents,
172
+ height_control_image,
173
+ width_control_image,
174
+ )
175
+ control_images.append(control_image_)
176
+
177
+ control_image = control_images
178
+
179
+ # Here we ensure that `control_mode` has the same length as the control_image.
180
+ if isinstance(control_mode, list) and len(control_mode) != len(control_image):
181
+ raise ValueError(
182
+ "For Multi-ControlNet, `control_mode` must be a list of the same "
183
+ + "length as the number of controlnets (control images) specified"
184
+ )
185
+ if not isinstance(control_mode, list):
186
+ control_mode = [control_mode] * len(control_image)
187
+ # set control mode
188
+ control_modes = []
189
+ for cmode in control_mode:
190
+ if cmode is None:
191
+ cmode = -1
192
+ control_mode = torch.tensor(cmode).expand(control_images[0].shape[0]).to(device, dtype=torch.long)
193
+ control_modes.append(control_mode)
194
+ control_mode = control_modes
195
+
196
+ return control_image, control_mode
197
+
198
+ def get_controlnet_keep(self, timesteps, control_guidance_start, control_guidance_end):
199
+ controlnet_keep = []
200
+ for i in range(len(timesteps)):
201
+ keeps = [
202
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
203
+ for s, e in zip(control_guidance_start, control_guidance_end)
204
+ ]
205
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, BriaControlNetModel) else keeps)
206
+ return controlnet_keep
207
+
208
+ def get_control_start_end(self, control_guidance_start, control_guidance_end):
209
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
210
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
211
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
212
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
213
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
214
+ mult = 1 # TODO - why is this 1?
215
+ control_guidance_start, control_guidance_end = (
216
+ mult * [control_guidance_start],
217
+ mult * [control_guidance_end],
218
+ )
219
+
220
+ return control_guidance_start, control_guidance_end
221
+
222
+ @torch.no_grad()
223
+ def __call__(
224
+ self,
225
+ prompt: Union[str, List[str]] = None,
226
+ height: Optional[int] = None,
227
+ width: Optional[int] = None,
228
+ num_inference_steps: int = 30,
229
+ timesteps: List[int] = None,
230
+ guidance_scale: float = 3.5,
231
+ control_guidance_start: Union[float, List[float]] = 0.0,
232
+ control_guidance_end: Union[float, List[float]] = 1.0,
233
+ control_image: Optional[PipelineImageInput] = None,
234
+ control_mode: Optional[Union[int, List[int]]] = None,
235
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
236
+ negative_prompt: Optional[Union[str, List[str]]] = None,
237
+ num_images_per_prompt: Optional[int] = 1,
238
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
239
+ latents: Optional[torch.FloatTensor] = None,
240
+ prompt_embeds: Optional[torch.FloatTensor] = None,
241
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
242
+ output_type: Optional[str] = "pil",
243
+ return_dict: bool = True,
244
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
245
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
246
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
247
+ max_sequence_length: int = 128,
248
+ ):
249
+ r"""
250
+ Function invoked when calling the pipeline for generation.
251
+
252
+ Args:
253
+ prompt (`str` or `List[str]`, *optional*):
254
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
255
+ instead.
256
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
257
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
258
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
259
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
260
+ num_inference_steps (`int`, *optional*, defaults to 30):
261
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
262
+ expense of slower inference.
263
+ timesteps (`List[int]`, *optional*):
264
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
265
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
266
+ passed will be used. Must be in descending order.
267
+ guidance_scale (`float`, *optional*, defaults to 3.5):
268
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
269
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
270
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
271
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
272
+ usually at the expense of lower image quality.
273
+ negative_prompt (`str` or `List[str]`, *optional*):
274
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
275
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
276
+ less than `1`).
277
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
278
+ The number of images to generate per prompt.
279
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
280
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
281
+ to make generation deterministic.
282
+ latents (`torch.FloatTensor`, *optional*):
283
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
284
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
285
+ tensor will be generated by sampling using the supplied random `generator`.
286
+ prompt_embeds (`torch.FloatTensor`, *optional*):
287
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
288
+ provided, text embeddings will be generated from `prompt` input argument.
289
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
290
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
291
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
292
+ argument.
293
+ output_type (`str`, *optional*, defaults to `"pil"`):
294
+ The output format of the generate image. Choose between
295
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
296
+ return_dict (`bool`, *optional*, defaults to `True`):
297
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead
298
+ of a plain tuple.
299
+ joint_attention_kwargs (`dict`, *optional*):
300
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
301
+ `self.processor` in
302
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
303
+ callback_on_step_end (`Callable`, *optional*):
304
+ A function that calls at the end of each denoising steps during the inference. The function is called
305
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
306
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
307
+ `callback_on_step_end_tensor_inputs`.
308
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
309
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
310
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
311
+ `._callback_tensor_inputs` attribute of your pipeline class.
312
+ max_sequence_length (`int` defaults to 128): Maximum sequence length to use with the `prompt`.
313
+
314
+ Examples:
315
+
316
+ Returns:
317
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`:
318
+ [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a
319
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
320
+ """
321
+
322
+ height = height or self.default_sample_size * self.vae_scale_factor
323
+ width = width or self.default_sample_size * self.vae_scale_factor
324
+ control_guidance_start, control_guidance_end = self.get_control_start_end(
325
+ control_guidance_start=control_guidance_start, control_guidance_end=control_guidance_end
326
+ )
327
+
328
+ # 1. Check inputs. Raise error if not correct
329
+ self.check_inputs(
330
+ prompt,
331
+ height,
332
+ width,
333
+ negative_prompt=negative_prompt,
334
+ prompt_embeds=prompt_embeds,
335
+ negative_prompt_embeds=negative_prompt_embeds,
336
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
337
+ max_sequence_length=max_sequence_length,
338
+ )
339
+
340
+ self._guidance_scale = guidance_scale
341
+ self._joint_attention_kwargs = joint_attention_kwargs
342
+ self._interrupt = False
343
+
344
+ # 2. Define call parameters
345
+ if prompt is not None and isinstance(prompt, str):
346
+ batch_size = 1
347
+ elif prompt is not None and isinstance(prompt, list):
348
+ batch_size = len(prompt)
349
+ else:
350
+ batch_size = prompt_embeds.shape[0]
351
+
352
+ device = self._execution_device
353
+
354
+ lora_scale = self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
355
+
356
+ (prompt_embeds, negative_prompt_embeds, text_ids) = self.encode_prompt(
357
+ prompt=prompt,
358
+ negative_prompt=negative_prompt,
359
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
360
+ prompt_embeds=prompt_embeds,
361
+ negative_prompt_embeds=negative_prompt_embeds,
362
+ device=device,
363
+ num_images_per_prompt=num_images_per_prompt,
364
+ max_sequence_length=max_sequence_length,
365
+ lora_scale=lora_scale,
366
+ )
367
+
368
+ if self.do_classifier_free_guidance:
369
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
370
+
371
+ # 3. Prepare control image
372
+ if control_image is not None:
373
+ if isinstance(self.controlnet, BriaControlNetModel):
374
+ control_image, control_mode = self.prepare_control(
375
+ control_image=control_image,
376
+ width=width,
377
+ height=height,
378
+ batch_size=batch_size,
379
+ num_images_per_prompt=num_images_per_prompt,
380
+ device=device,
381
+ control_mode=control_mode,
382
+ )
383
+ elif isinstance(self.controlnet, BriaMultiControlNetModel):
384
+ control_image, control_mode = self.prepare_multi_control(
385
+ control_image=control_image,
386
+ width=width,
387
+ height=height,
388
+ batch_size=batch_size,
389
+ num_images_per_prompt=num_images_per_prompt,
390
+ device=device,
391
+ control_mode=control_mode,
392
+ )
393
+
394
+ # 4. Prepare timesteps
395
+ # Sample from training sigmas
396
+ sigmas = get_original_sigmas(
397
+ num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps
398
+ )
399
+ timesteps, num_inference_steps = retrieve_timesteps(
400
+ self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas
401
+ )
402
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
403
+ self._num_timesteps = len(timesteps)
404
+
405
+ # 5. Prepare latent variables
406
+ num_channels_latents = self.transformer.config.in_channels // 4  # due to patch=2, we divide by 4
407
+ latents, latent_image_ids = self.prepare_latents(
408
+ batch_size=batch_size * num_images_per_prompt,
409
+ num_channels_latents=num_channels_latents,
410
+ height=height,
411
+ width=width,
412
+ dtype=prompt_embeds.dtype,
413
+ device=device,
414
+ generator=generator,
415
+ latents=latents,
416
+ )
417
+
418
+ # 6. Create tensor stating which controlnets to keep
419
+ if control_image is not None:
420
+ controlnet_keep = self.get_controlnet_keep(
421
+ timesteps=timesteps,
422
+ control_guidance_start=control_guidance_start,
423
+ control_guidance_end=control_guidance_end,
424
+ )
425
+
426
+ # EYAL - added the CFG loop
427
+ # 7. Denoising loop
428
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
429
+ for i, t in enumerate(timesteps):
430
+ if self.interrupt:
431
+ continue
432
+
433
+ # expand the latents if we are doing classifier free guidance
434
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
435
+ # if type(self.scheduler) != FlowMatchEulerDiscreteScheduler:
436
+ if not isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
437
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
438
+
439
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
440
+ timestep = t.expand(latent_model_input.shape[0])
441
+
442
+ # Handling ControlNet
443
+ if control_image is not None:
444
+ if isinstance(controlnet_keep[i], list):
445
+ if isinstance(controlnet_conditioning_scale, list):
446
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
447
+ else:
448
+ cond_scale = [controlnet_conditioning_scale * s for s in controlnet_keep[i]]
449
+ else:
450
+ controlnet_cond_scale = controlnet_conditioning_scale
451
+ if isinstance(controlnet_cond_scale, list):
452
+ controlnet_cond_scale = controlnet_cond_scale[0]
453
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
454
+
455
+ # controlnet
456
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
457
+ hidden_states=latents,
458
+ controlnet_cond=control_image,
459
+ controlnet_mode=control_mode,
460
+ conditioning_scale=cond_scale,
461
+ timestep=timestep,
462
+ # guidance=guidance,
463
+ # pooled_projections=pooled_prompt_embeds,
464
+ encoder_hidden_states=prompt_embeds,
465
+ txt_ids=text_ids,
466
+ img_ids=latent_image_ids,
467
+ joint_attention_kwargs=self.joint_attention_kwargs,
468
+ return_dict=False,
469
+ )
470
+ else:
471
+ controlnet_block_samples, controlnet_single_block_samples = None, None
472
+
473
+ # This predicts "v" from flow-matching
474
+ noise_pred = self.transformer(
475
+ hidden_states=latent_model_input,
476
+ timestep=timestep,
477
+ encoder_hidden_states=prompt_embeds,
478
+ joint_attention_kwargs=self.joint_attention_kwargs,
479
+ return_dict=False,
480
+ txt_ids=text_ids,
481
+ img_ids=latent_image_ids,
482
+ controlnet_block_samples=controlnet_block_samples,
483
+ controlnet_single_block_samples=controlnet_single_block_samples,
484
+ )[0]
485
+
486
+ # perform guidance
487
+ if self.do_classifier_free_guidance:
488
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
489
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
490
+
491
+ # compute the previous noisy sample x_t -> x_t-1
492
+ latents_dtype = latents.dtype
493
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
494
+
495
+ if latents.dtype != latents_dtype:
496
+ if torch.backends.mps.is_available():
497
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
498
+ latents = latents.to(latents_dtype)
499
+
500
+ if callback_on_step_end is not None:
501
+ callback_kwargs = {}
502
+ for k in callback_on_step_end_tensor_inputs:
503
+ callback_kwargs[k] = locals()[k]
504
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
505
+
506
+ latents = callback_outputs.pop("latents", latents)
507
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
508
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
509
+
510
+ # call the callback, if provided
511
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
512
+ progress_bar.update()
513
+
514
+ if XLA_AVAILABLE:
515
+ xm.mark_step()
516
+
517
+ if output_type == "latent":
518
+ image = latents
519
+
520
+ else:
521
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
522
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
523
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
524
+ image = self.image_processor.postprocess(image, output_type=output_type)
525
+
526
+ # Offload all models
527
+ self.maybe_free_model_hooks()
528
+
529
+ if not return_dict:
530
+ return (image,)
531
+
532
+ return FluxPipelineOutput(images=image)
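
As a companion to the `__call__` docstring above, a minimal usage sketch of `BriaControlNetPipeline`. The checkpoint paths, the conditioning-image location, the `control_mode` id, and the dtype/device choices are placeholders and assumptions, not values taken from this repository:

import torch
from diffusers.utils import load_image

from controlnet_bria import BriaControlNetModel
from pipeline_bria_controlnet import BriaControlNetPipeline

# Placeholder checkpoint locations: substitute the actual Bria artifacts.
controlnet = BriaControlNetModel.from_pretrained("<bria-controlnet-checkpoint>", torch_dtype=torch.bfloat16)
pipe = BriaControlNetPipeline.from_pretrained(
    "<bria-base-checkpoint>", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

control_image = load_image("<path-or-url-to-conditioning-image>")

image = pipe(
    prompt="a portrait photo, studio lighting",
    control_image=control_image,
    control_mode=0,                      # assumed mode id; depends on the trained controlnet
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,              # matches the signature default above
    guidance_scale=3.5,
    height=1024,
    width=1024,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("bria_controlnet_result.png")
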
transformer_bria.py ADDED
@@ -0,0 +1,336 @@
1
+ from typing import Any, Dict, List, Optional, Union
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.loaders import PeftAdapterMixin, FromOriginalModelMixin
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+ from diffusers.models.normalization import AdaLayerNormContinuous
9
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
10
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
11
+ from diffusers.models.embeddings import TimestepEmbedding, get_timestep_embedding
12
+ from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
13
+
14
+ # Support different diffusers versions
15
+ try:
16
+ from diffusers.models.embeddings import FluxPosEmbed as EmbedND
17
+ except:
18
+ from diffusers.models.transformers.transformer_flux import rope
19
+ class EmbedND(nn.Module):
20
+ def __init__(self, theta: int, axes_dim: List[int]):
21
+ super().__init__()
22
+ self.theta = theta
23
+ self.axes_dim = axes_dim
24
+
25
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
26
+ n_axes = ids.shape[-1]
27
+ emb = torch.cat(
28
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
29
+ dim=-3,
30
+ )
31
+ return emb.unsqueeze(1)
32
+
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+ class Timesteps(nn.Module):
38
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, max_period=10000):
39
+ super().__init__()
40
+ self.num_channels = num_channels
41
+ self.flip_sin_to_cos = flip_sin_to_cos
42
+ self.downscale_freq_shift = downscale_freq_shift
43
+ self.scale = scale
44
+ self.max_period = max_period
45
+
46
+ def forward(self, timesteps):
47
+ t_emb = get_timestep_embedding(
48
+ timesteps,
49
+ self.num_channels,
50
+ flip_sin_to_cos=self.flip_sin_to_cos,
51
+ downscale_freq_shift=self.downscale_freq_shift,
52
+ scale=self.scale,
53
+ max_period=self.max_period
54
+ )
55
+ return t_emb
56
+
57
+ class TimestepProjEmbeddings(nn.Module):
58
+ def __init__(self, embedding_dim, max_period):
59
+ super().__init__()
60
+
61
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0,max_period=max_period)
62
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
63
+
64
+ def forward(self, timestep, dtype):
65
+ timesteps_proj = self.time_proj(timestep)
66
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
67
+ return timesteps_emb
68
+
69
+ """
70
+ Based on the Flux transformer (FluxTransformer2DModel) with several changes:
71
+ - No pooled embeddings
72
+ - Zero padding is used for prompts
73
+ - No guidance embedding, since this is not a distilled version
74
+ """
75
+ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
76
+ """
77
+ The Transformer model introduced in Flux.
78
+
79
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
80
+
81
+ Parameters:
82
+ patch_size (`int`): Patch size to turn the input data into small patches.
83
+ in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
84
+ num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
85
+ num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
86
+ attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
87
+ num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
88
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
89
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
90
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
91
+ """
92
+
93
+ _supports_gradient_checkpointing = True
94
+
95
+ @register_to_config
96
+ def __init__(
97
+ self,
98
+ patch_size: int = 1,
99
+ in_channels: int = 64,
100
+ num_layers: int = 19,
101
+ num_single_layers: int = 38,
102
+ attention_head_dim: int = 128,
103
+ num_attention_heads: int = 24,
104
+ joint_attention_dim: int = 4096,
105
+ pooled_projection_dim: int = None,
106
+ guidance_embeds: bool = False,
107
+ axes_dims_rope: List[int] = [16, 56, 56],
108
+ rope_theta = 10000,
109
+ max_period = 10000
110
+ ):
111
+ super().__init__()
112
+ self.out_channels = in_channels
113
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
114
+
115
+ self.pos_embed = EmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
116
+
117
+
118
+ self.time_embed = TimestepProjEmbeddings(
119
+ embedding_dim=self.inner_dim,max_period=max_period
120
+ )
121
+
122
+ # if pooled_projection_dim:
123
+ # self.pooled_text_embed = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim=self.inner_dim, act_fn="silu")
124
+
125
+ if guidance_embeds:
126
+ self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, max_period=max_period)
127
+
128
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
129
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
130
+
131
+ self.transformer_blocks = nn.ModuleList(
132
+ [
133
+ FluxTransformerBlock(
134
+ dim=self.inner_dim,
135
+ num_attention_heads=self.config.num_attention_heads,
136
+ attention_head_dim=self.config.attention_head_dim,
137
+ )
138
+ for i in range(self.config.num_layers)
139
+ ]
140
+ )
141
+
142
+ self.single_transformer_blocks = nn.ModuleList(
143
+ [
144
+ FluxSingleTransformerBlock(
145
+ dim=self.inner_dim,
146
+ num_attention_heads=self.config.num_attention_heads,
147
+ attention_head_dim=self.config.attention_head_dim,
148
+ )
149
+ for i in range(self.config.num_single_layers)
150
+ ]
151
+ )
152
+
153
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
154
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
155
+
156
+ self.gradient_checkpointing = False
157
+
158
+ def _set_gradient_checkpointing(self, module, value=False):
159
+ if hasattr(module, "gradient_checkpointing"):
160
+ module.gradient_checkpointing = value
161
+
162
+ def forward(
163
+ self,
164
+ hidden_states: torch.Tensor,
165
+ encoder_hidden_states: torch.Tensor = None,
166
+ pooled_projections: torch.Tensor = None,
167
+ timestep: torch.LongTensor = None,
168
+ img_ids: torch.Tensor = None,
169
+ txt_ids: torch.Tensor = None,
170
+ guidance: torch.Tensor = None,
171
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
172
+ return_dict: bool = True,
173
+ controlnet_block_samples = None,
174
+ controlnet_single_block_samples=None,
175
+
176
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
177
+ """
178
+ The [`BriaTransformer2DModel`] forward method.
179
+
180
+ Args:
181
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
182
+ Input `hidden_states`.
183
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
184
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
185
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
186
+ from the embeddings of input conditions.
187
+ timestep ( `torch.LongTensor`):
188
+ Used to indicate denoising step.
189
+ controlnet_block_samples (`list` of `torch.Tensor`):
190
+ A list of tensors that if specified are added to the residuals of transformer blocks.
191
+ joint_attention_kwargs (`dict`, *optional*):
192
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
193
+ `self.processor` in
194
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
195
+ return_dict (`bool`, *optional*, defaults to `True`):
196
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
197
+ tuple.
198
+
199
+ Returns:
200
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
201
+ `tuple` where the first element is the sample tensor.
202
+ """
203
+ if joint_attention_kwargs is not None:
204
+ joint_attention_kwargs = joint_attention_kwargs.copy()
205
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
206
+ else:
207
+ lora_scale = 1.0
208
+
209
+ if USE_PEFT_BACKEND:
210
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
211
+ scale_lora_layers(self, lora_scale)
212
+ else:
213
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
214
+ logger.warning(
215
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
216
+ )
217
+ hidden_states = self.x_embedder(hidden_states)
218
+
219
+ timestep = timestep.to(hidden_states.dtype)
220
+ if guidance is not None:
221
+ guidance = guidance.to(hidden_states.dtype)
222
+ else:
223
+ guidance = None
224
+
225
+ # temb = (
226
+ # self.time_text_embed(timestep, pooled_projections)
227
+ # if guidance is None
228
+ # else self.time_text_embed(timestep, guidance, pooled_projections)
229
+ # )
230
+
231
+ temb = self.time_embed(timestep, dtype=hidden_states.dtype)
232
+
233
+ # if pooled_projections:
234
+ # temb+=self.pooled_text_embed(pooled_projections)
235
+
236
+ if guidance:
237
+ temb+=self.guidance_embed(guidance,dtype=hidden_states.dtype)
238
+
239
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
240
+
241
+ if len(txt_ids.shape)==2:
242
+ ids = torch.cat((txt_ids, img_ids), dim=0)
243
+ else:
244
+ ids = torch.cat((txt_ids, img_ids), dim=1)
245
+ image_rotary_emb = self.pos_embed(ids)
246
+
247
+ for index_block, block in enumerate(self.transformer_blocks):
248
+ if self.training and self.gradient_checkpointing:
249
+
250
+ def create_custom_forward(module, return_dict=None):
251
+ def custom_forward(*inputs):
252
+ if return_dict is not None:
253
+ return module(*inputs, return_dict=return_dict)
254
+ else:
255
+ return module(*inputs)
256
+
257
+ return custom_forward
258
+
259
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
260
+ encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
261
+ create_custom_forward(block),
262
+ hidden_states,
263
+ encoder_hidden_states,
264
+ temb,
265
+ image_rotary_emb,
266
+ **ckpt_kwargs,
267
+ )
268
+
269
+ else:
270
+ encoder_hidden_states, hidden_states = block(
271
+ hidden_states=hidden_states,
272
+ encoder_hidden_states=encoder_hidden_states,
273
+ temb=temb,
274
+ image_rotary_emb=image_rotary_emb,
275
+ )
276
+
277
+ # controlnet residual
278
+ if controlnet_block_samples is not None:
279
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
280
+ interval_control = int(np.ceil(interval_control))
281
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
282
+
283
+
284
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
285
+
286
+ for index_block, block in enumerate(self.single_transformer_blocks):
287
+ if self.training and self.gradient_checkpointing:
288
+
289
+ def create_custom_forward(module, return_dict=None):
290
+ def custom_forward(*inputs):
291
+ if return_dict is not None:
292
+ return module(*inputs, return_dict=return_dict)
293
+ else:
294
+ return module(*inputs)
295
+
296
+ return custom_forward
297
+
298
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
299
+ hidden_states = torch.utils.checkpoint.checkpoint(
300
+ create_custom_forward(block),
301
+ hidden_states,
302
+ temb,
303
+ image_rotary_emb,
304
+ **ckpt_kwargs,
305
+ )
306
+
307
+ else:
308
+ hidden_states = block(
309
+ hidden_states=hidden_states,
310
+ temb=temb,
311
+ image_rotary_emb=image_rotary_emb,
312
+ )
313
+
314
+ # controlnet residual
315
+ if controlnet_single_block_samples is not None:
316
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
317
+ interval_control = int(np.ceil(interval_control))
318
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
319
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
320
+ + controlnet_single_block_samples[index_block // interval_control]
321
+ )
322
+
323
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
324
+
325
+ hidden_states = self.norm_out(hidden_states, temb)
326
+ output = self.proj_out(hidden_states)
327
+
328
+ if USE_PEFT_BACKEND:
329
+ # remove `lora_scale` from each PEFT layer
330
+ unscale_lora_layers(self, lora_scale)
331
+
332
+ if not return_dict:
333
+ return (output,)
334
+
335
+ return Transformer2DModelOutput(sample=output)
336
+
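
The controlnet residuals in the forward pass above are spread over the transformer blocks through `interval_control`. A standalone sketch of that index mapping; the block and sample counts below are illustrative, not taken from a specific checkpoint:

import numpy as np

def controlnet_sample_index(num_blocks: int, num_samples: int) -> list:
    # Mirrors `interval_control = int(np.ceil(len(blocks) / len(samples)))`
    # followed by `index_block // interval_control` in the forward pass.
    interval_control = int(np.ceil(num_blocks / num_samples))
    return [index_block // interval_control for index_block in range(num_blocks)]

# Example: 19 double-stream blocks fed by 4 controlnet block samples.
print(controlnet_sample_index(19, 4))
# [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3]
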