fumeisama committed · Commit a5418f5 · verified · 1 Parent(s): 2e91294

Create pixart_transformer_modified.py

transformer/pixart_transformer_modified.py ADDED
@@ -0,0 +1,913 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Union, Tuple, List

import torch
from torch import nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import is_torch_version, logging, deprecate
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0, JointAttnProcessor2_0
from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX, AdaLayerNormSingle
from torch.nn.utils.rnn import pad_sequence
from einops import rearrange
import numpy as np
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
from diffusers.models.embeddings import SinusoidalPositionalEmbedding


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
    r"""
    A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
    https://arxiv.org/abs/2403.04692).

    Parameters:
        num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
        attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
        in_channels (int, defaults to 4): The number of channels in the input.
        out_channels (int, optional):
            The number of channels in the output. Specify this parameter if the output channel number differs from the
            input.
        num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
        dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
        norm_num_groups (int, optional, defaults to 32):
            Number of groups for group normalization within Transformer blocks.
        cross_attention_dim (int, optional):
            The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
        attention_bias (bool, optional, defaults to True):
            Configure if the Transformer blocks' attention should contain a bias parameter.
        sample_size (int, defaults to 128):
            The width of the latent images. This parameter is fixed during training.
        patch_size (int, defaults to 2):
            Size of the patches the model processes, relevant for architectures working on non-sequential data.
        activation_fn (str, optional, defaults to "gelu-approximate"):
            Activation function to use in feed-forward networks within Transformer blocks.
        num_embeds_ada_norm (int, optional, defaults to 1000):
            Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
            inference.
        upcast_attention (bool, optional, defaults to False):
            If true, upcasts the attention mechanism dimensions for potentially improved performance.
        norm_type (str, optional, defaults to "ada_norm_single"):
            Specifies the type of normalization used. Only 'ada_norm_single' is supported by this implementation.
        norm_elementwise_affine (bool, optional, defaults to False):
            If true, enables element-wise affine parameters in the normalization layers.
        norm_eps (float, optional, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
        use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
        attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
        caption_channels (int, optional, defaults to None):
            Number of channels to use for projecting the caption embeddings.
        use_linear_projection (bool, optional, defaults to False):
            Deprecated argument. Will be removed in a future version.
        num_vector_embeds (bool, optional, defaults to False):
            Deprecated argument. Will be removed in a future version.
    """
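    # Illustrative usage sketch (not part of the original code; caption_channels=4096 and the shapes below are
    # assumptions for a T5-conditioned, PixArt-like setup):
    #
    #   model = PixArtTransformer2DModel(caption_channels=4096, use_additional_conditions=False)
    #   out = model(
    #       hidden_states=torch.randn(2, 4, 128, 128),
    #       encoder_hidden_states=torch.randn(2, 120, 4096),
    #       encoder_attention_mask=torch.ones(2, 120),
    #       text_bboxes=[[], []],
    #       character_bboxes=[[], []],
    #       reference_embeddings=[[], []],
    #       timestep=torch.tensor([999, 999]),
    #   ).sample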

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 72,
        in_channels: int = 4,
        out_channels: Optional[int] = 8,
        num_layers: int = 28,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = 1152,
        attention_bias: bool = True,
        sample_size: int = 128,
        patch_size: int = 2,
        activation_fn: str = "gelu-approximate",
        num_embeds_ada_norm: Optional[int] = 1000,
        upcast_attention: bool = False,
        norm_type: str = "ada_norm_single",
        norm_elementwise_affine: bool = False,
        norm_eps: float = 1e-6,
        interpolation_scale: Optional[int] = None,
        use_additional_conditions: Optional[bool] = None,
        caption_channels: Optional[int] = None,
        attention_type: Optional[str] = "default",
    ):
        super().__init__()

        # Validate inputs.
        if norm_type != "ada_norm_single":
            raise NotImplementedError(
                f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
            )
        elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
            raise ValueError(
                f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
            )

        # Set some common variables used across the board.
        self.attention_head_dim = attention_head_dim
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.out_channels = in_channels if out_channels is None else out_channels
        if use_additional_conditions is None:
            if sample_size == 128:
                use_additional_conditions = True
            else:
                use_additional_conditions = False
        self.use_additional_conditions = use_additional_conditions

        self.gradient_checkpointing = False

        # 2. Initialize the position embedding and transformer blocks.
        self.height = self.config.sample_size
        self.width = self.config.sample_size

        interpolation_scale = (
            self.config.interpolation_scale
            if self.config.interpolation_scale is not None
            else max(self.config.sample_size // 64, 1)
        )
        self.pos_embed = PatchEmbed(
            height=self.config.sample_size,
            width=self.config.sample_size,
            patch_size=self.config.patch_size,
            in_channels=self.config.in_channels,
            embed_dim=self.inner_dim,
            interpolation_scale=interpolation_scale,
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.config.num_attention_heads,
                    self.config.attention_head_dim,
                    dropout=self.config.dropout,
                    cross_attention_dim=self.config.cross_attention_dim,
                    activation_fn=self.config.activation_fn,
                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
                    attention_bias=self.config.attention_bias,
                    upcast_attention=self.config.upcast_attention,
                    norm_type=norm_type,
                    norm_elementwise_affine=self.config.norm_elementwise_affine,
                    norm_eps=self.config.norm_eps,
                    attention_type=self.config.attention_type,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # 3. Output blocks.
        self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=self.use_additional_conditions
        )
        self.caption_projection = None
        if self.config.caption_channels is not None:
            self.caption_projection = PixArtAlphaTextProjection(
                in_features=self.config.caption_channels, hidden_size=self.inner_dim
            )
        self.ip_adapter = IPAdapter()

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.

        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
        """
        self.set_attn_processor(AttnProcessor())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        ip_hidden_states: torch.Tensor = None,
        ip_attention_mask: torch.Tensor = None,
        text_bboxes=None,
        character_bboxes=None,
        reference_embeddings=None,
        cfg_on_10_percent=False,
        timestep: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Dict[str, torch.Tensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        return_dict: bool = True,
    ):
        """
        The [`PixArtTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                self-attention.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
            added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
            cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            encoder_attention_mask ( `torch.Tensor`, *optional*):
                Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

                * Mask `(batch, sequence_length)` True = keep, False = discard.
                * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

                If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
                above. This bias will be added to the cross-attention scores.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
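        # Notes on the inputs added on top of the stock diffusers PixArtTransformer2DModel (derived from the code
        # below; the [0, 1] box normalization is an assumption based on RBFEmbedding's kernel means):
        #   * either pass precomputed `ip_hidden_states` / `ip_attention_mask`, or pass `text_bboxes`,
        #     `character_bboxes` and `reference_embeddings` and let `self.ip_adapter` build them (see the assert).
        #   * `text_bboxes` / `character_bboxes`: one list of boxes per sample; each box is a 4-value coordinate
        #     vector fed to `RBFEmbedding`, presumably normalized `[x1, y1, x2, y2]` in [0, 1].
        #   * `reference_embeddings`: per character, either None or a 768-d embedding projected to 4 identity tokens.
        #   * `cfg_on_10_percent`: if True, the IP conditioning of the last ~10% of the batch is replaced with a
        #     learned negative embedding (classifier-free-guidance-style dropout).
        #   * `hidden_states` may hold latents of different spatial sizes; each is patch-embedded separately and
        #     padded to a common sequence length.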
        if self.use_additional_conditions and added_cond_kwargs is None:
            raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
        # 0. Prompt Embedding Modification
        assert (ip_hidden_states is None) ^ (text_bboxes is None and character_bboxes is None and reference_embeddings is None)
        if ip_hidden_states is None:
            ip_hidden_states, ip_attention_mask = self.ip_adapter(text_bboxes, character_bboxes, reference_embeddings, cfg_on_10_percent)

        # 1. Input
        batch_size = len(hidden_states)
        heights = [h.shape[-2] // self.config.patch_size for h in hidden_states]
        widths = [w.shape[-1] // self.config.patch_size for w in hidden_states]
        hidden_states = [self.pos_embed(hs[None])[0] for hs in hidden_states]
        attention_mask = [torch.ones(x.shape[0]) for x in hidden_states]
        hidden_states = pad_sequence(hidden_states, batch_first=True)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0).bool().to(hidden_states.device)
        original_attention_mask = attention_mask

        timestep, embedded_timestep = self.adaln_single(
            timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
        )
        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
        # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
        # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None and attention_mask.ndim == 2:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if self.caption_projection is not None:
            encoder_hidden_states = self.caption_projection(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

        # 2. Blocks
        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    ip_hidden_states,
                    ip_attention_mask,
                    timestep,
                    cross_attention_kwargs,
                    None,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    ip_hidden_states=ip_hidden_states,
                    ip_attention_mask=ip_attention_mask,
                    timestep=timestep,
                    cross_attention_kwargs=cross_attention_kwargs,
                    class_labels=None,
                )

        # 3. Output
        shift, scale = (
            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
        ).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states)
        # Modulation
        hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.squeeze(1)

        # unpatchify
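        # At this point `hidden_states` is (batch, padded_seq_len, patch_size * patch_size * out_channels).
        # For every sample we keep only its real (unpadded) tokens, fold them back into
        # (height, width, patch, patch, channels) and rearrange to an image-like (channels, H, W) tensor,
        # so samples of different resolutions are reconstructed at their own size.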
        outputs = []
        for idx, (height, width) in enumerate(zip(heights, widths)):
            _hidden_state = hidden_states[idx][original_attention_mask[idx]].reshape(
                shape=(height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
            )
            _hidden_state = torch.einsum("hwpqc->chpwq", _hidden_state)
            outputs.append(_hidden_state.reshape(
                shape=(self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
            ))

        if len(set([x.shape for x in outputs])) == 1:
            outputs = torch.stack(outputs)

        if not return_dict:
            return (outputs,)

        return Transformer2DModelOutput(sample=outputs)


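# RBFEmbedding turns a bounding box into a single token: every one of its 4 coordinates is expanded into
# `num_kernels` Gaussian radial-basis responses around learnable means (initialized on a grid over [0, 1],
# which suggests boxes are expected as normalized coordinates) with learnable widths, and the flattened
# 4 * num_kernels features are linearly projected to `output_dim` (e.g. 4 * 32 = 128 -> 1152 here).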
class RBFEmbedding(nn.Module):
    def __init__(self, output_dim, num_kernels=32):
        super().__init__()
        self.means = nn.Parameter(torch.linspace(0, 1, num_kernels))
        self.scales = nn.Parameter(torch.ones(num_kernels) * 20)
        self.proj = nn.Linear(num_kernels * 4, output_dim)

    def forward(self, box):
        box = torch.tensor(box, dtype=self.means.dtype, device=self.means.device)
        x = box.unsqueeze(-1) - self.means
        x = torch.exp(-0.5 * (x * self.scales.unsqueeze(0)) ** 2)
        x = x.reshape(-1)
        return self.proj(x)

    def participate_in_grad(self):
        return self.proj.weight.sum() + self.proj.bias.sum() + self.means.sum() + self.scales.sum()

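# Rotary positional embedding over the IP token sequence: channel pairs (x_{2i}, x_{2i+1}) are rotated by an
# angle theta = position * freq_i with freq_i = base**(-i / (dim / 2)), i.e.
#   x'_{2i}   = x_{2i} * cos(theta) - x_{2i+1} * sin(theta)
#   x'_{2i+1} = x_{2i} * sin(theta) + x_{2i+1} * cos(theta)
# computed in float32 and cast back to the input dtype.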
class RoPEPositionalEmbedding(nn.Module):
    def __init__(self, embedding_dim, base=10000):
        super().__init__()
        self.embedding_dim = embedding_dim
        assert embedding_dim % 2 == 0, "Embedding dimension must be even"
        half_dim = embedding_dim // 2
        freqs = 1.0 / (base ** (torch.arange(0, half_dim).float() / half_dim))
        self.register_buffer("freqs", freqs)

    def forward(self, x, positions):
        orig_dtype = x.dtype
        x = x.float()
        positions = positions.float()
        x_2d = rearrange(x, '... (d two) -> ... d two', two=2)  # [..., dim/2, 2]
        positions = positions.unsqueeze(-1) * self.freqs.float()  # [seq_len, dim/2]
        sin = positions.sin().unsqueeze(-1)  # [seq_len, dim/2, 1]
        cos = positions.cos().unsqueeze(-1)  # [seq_len, dim/2, 1]
        x_out = torch.cat([
            x_2d[..., 0:1] * cos - x_2d[..., 1:2] * sin,
            x_2d[..., 0:1] * sin + x_2d[..., 1:2] * cos,
        ], dim=-1)
        output = rearrange(x_out, '... d two -> ... (d two)')
        return output.to(orig_dtype)

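# IPAdapter builds the "IP" conditioning sequence consumed by `ip_attn` in every transformer block:
#   * each text box becomes [RBF box token, text class token];
#   * each character box becomes [RBF box token, 4 identity tokens] - either the learned class tokens or a
#     768-d reference embedding projected into 4 tokens;
#   * RoPE is applied over the concatenated token sequence, sequences are padded across the batch, and a
#     learned `void` token stands in for samples with no boxes;
#   * with `cfg_on_10_percent`, the last ~10% of the batch is replaced by a learned negative embedding,
#     presumably for classifier-free guidance (the `cfg` in the flag name).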
class IPAdapter(ModelMixin):
    def __init__(self):
        super().__init__()
        self.embedding_dim = 1152
        self.box_embedding = RBFEmbedding(self.embedding_dim)
        self.pos_embedding = RoPEPositionalEmbedding(self.embedding_dim)
        self.text_cls_embedding = nn.Embedding(1, self.embedding_dim)
        self.character_cls_embedding = nn.Embedding(4, self.embedding_dim)
        self.ref_embedding_proj = nn.Linear(768, 4 * self.embedding_dim)
        self.void_ip_embed = nn.Embedding(1, self.embedding_dim)
        self.negative_ip_embed = nn.Embedding(1, self.embedding_dim)
        self.norm = nn.LayerNorm(self.embedding_dim)

    def participate_in_grad(self):
        return sum([
            self.box_embedding.participate_in_grad(),
            self.text_cls_embedding.weight.sum(),
            self.character_cls_embedding.weight.sum(),
            self.ref_embedding_proj.weight.sum(),
            self.ref_embedding_proj.bias.sum(),
            self.void_ip_embed.weight.sum(),
            self.negative_ip_embed.weight.sum(),
            self.norm.weight.sum(),
            self.norm.bias.sum()
        ])

    def embed_text(self, box):
        box_embedding = self.box_embedding(box)
        return torch.stack([
            box_embedding,
            *self.text_cls_embedding.weight,
        ])

    def embed_character(self, character_bbox, reference_embedding):
        box_embedding = self.box_embedding(character_bbox)
        if reference_embedding is None:
            character_embedding = self.character_cls_embedding.weight
        else:
            character_embedding = self.ref_embedding_proj(reference_embedding.unsqueeze(0))
            character_embedding = rearrange(character_embedding, "1 (c h) -> h c", h=4)
        return torch.stack([
            box_embedding,
            *character_embedding
        ])

    def apply_position_embedding(self, embeddings):
        seq_length = embeddings.shape[0]
        positions = torch.arange(seq_length, device=embeddings.device, dtype=embeddings.dtype)
        return self.pos_embedding(embeddings, positions)

    def forward(self, batch_text_bboxes, batch_character_bboxes, batch_reference_embeddings, cfg_on_10_percent):
        ip_embeddings = []
        for batch_idx, (text_bboxes, character_bboxes, reference_embeddings) in enumerate(zip(batch_text_bboxes, batch_character_bboxes, batch_reference_embeddings)):
            text_embeddings = [self.embed_text(box) for box in text_bboxes]
            character_embeddings = [self.embed_character(box, reference_embeddings[i]) for i, box in enumerate(character_bboxes)]
            if len(text_embeddings) + len(character_embeddings) == 0:
                ip_embeddings.append(self.void_ip_embed.weight)
                continue
            ip_embedding = torch.cat(text_embeddings + character_embeddings, dim=0)
            ip_embeddings.append(self.apply_position_embedding(ip_embedding))

        ip_mask = [torch.ones(x.shape[0], dtype=torch.bool, device=x.device) for x in ip_embeddings]
        ip_embeddings = pad_sequence(ip_embeddings, batch_first=True, padding_value=0)
        ip_mask = pad_sequence(ip_mask, batch_first=True, padding_value=0).bool()
        if cfg_on_10_percent:
            # guard against int() rounding down to 0, which would make the slices below select the whole batch
            last_10_percent = max(1, int(len(ip_embeddings) * 0.1))
            ip_embeddings[-last_10_percent:] = self.negative_ip_embed.weight
            ip_mask[-last_10_percent:] = 0
            ip_mask[-last_10_percent:, :1] = 1
        return self.norm(ip_embeddings), ip_mask


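# Memory-saving helper: the token sequence is split along `chunk_dim` into `seq_len / chunk_size` slices and the
# feed-forward is run on each slice independently (the MLP acts on each token independently, so chunking does not
# change the result). For example, a hypothetical 4096-token sequence with chunk_size=1024 is processed in 4 passes.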
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
    # "feed_forward_chunk_size" can be used to save memory
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError(
            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
        )

    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    ff_output = torch.cat(
        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )
    return ff_output


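# Gated self-attention over concatenated visual + object tokens (the fuser used for GLIGEN-style layout
# conditioning in diffusers). It is kept here for completeness but is not referenced by the modified
# BasicTransformerBlock below.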
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
    r"""
    A gated self-attention dense layer that combines visual features and object features.

    Parameters:
        query_dim (`int`): The number of channels in the query.
        context_dim (`int`): The number of channels in the context.
        n_heads (`int`): The number of heads to use for attention.
        d_head (`int`): The number of channels in each head.
    """

    def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
        super().__init__()

        # we need a linear projection since we need cat visual feature and obj feature
        self.linear = nn.Linear(context_dim, query_dim)

        self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, activation_fn="geglu")

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

        self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
        self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

        self.enabled = True

    def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x

        n_visual = x.shape[1]
        objs = self.linear(objs)

        x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
        x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

        return x


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """
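    # Note: compared to the stock diffusers BasicTransformerBlock, this variant hard-codes ada_norm_single-style
    # modulation via `scale_shift_table` and adds a zero-initialized `ip_attn` cross-attention over the IPAdapter
    # tokens. Several constructor arguments (e.g. `norm_type`, `num_embeds_ada_norm`, `attention_type`,
    # `positional_embeddings`) are accepted but not used below.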

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        norm_type: str = "layer_norm",  # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
        norm_eps: float = 1e-5,
        final_dropout: bool = False,
        attention_type: str = "default",
        positional_embeddings: Optional[str] = None,
        num_positional_embeddings: Optional[int] = None,
        ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
        ada_norm_bias: Optional[int] = None,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        self.dropout = dropout
        self.cross_attention_dim = cross_attention_dim
        self.activation_fn = activation_fn
        self.attention_bias = attention_bias
        self.double_self_attention = double_self_attention
        self.norm_elementwise_affine = norm_elementwise_affine
        self.positional_embeddings = positional_embeddings
        self.num_positional_embeddings = num_positional_embeddings
        self.only_cross_attention = only_cross_attention

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

        # 2. Cross-Attn
        self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim if not double_self_attention else None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )

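        # IP-Adapter branch: a second cross-attention over the layout/identity tokens produced by IPAdapter.
        # Its output projection is zero-initialized below, so right after initialization the extra branch
        # contributes nothing and the block behaves like the pretrained PixArt block.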
        self.ip_attn = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim if not double_self_attention else None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
        )
        self.ip_attn.to_out[0].weight.data.zero_()
        self.ip_attn.to_out[0].bias.data.zero_()

        # 3. Feed-forward
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 5. Scale-shift for PixArt-Alpha.
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        ip_hidden_states: Optional[torch.Tensor] = None,
        ip_attention_mask: Optional[torch.Tensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> torch.Tensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

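        # `timestep` arrives here as the (batch, 6 * dim) modulation vector produced by AdaLayerNormSingle in the
        # parent model; adding it to the learned `scale_shift_table` and chunking gives shift/scale/gate triples
        # for the self-attention and feed-forward branches (ada_norm_single conditioning).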
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
        ).chunk(6, dim=1)
        norm_hidden_states = self.norm1(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa

        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}

        attn_output = self.attn1(
            norm_hidden_states,
            encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

        attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 3. Cross-Attention
        attn_output = self.attn2(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=encoder_attention_mask,
            **cross_attention_kwargs,
        )
        ip_attn_output = self.ip_attn(
            hidden_states,
            encoder_hidden_states=ip_hidden_states,
            attention_mask=ip_attention_mask,
            **cross_attention_kwargs,
        )
        hidden_states = attn_output + ip_attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
        else:
            ff_output = self.ff(norm_hidden_states)

        ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        return hidden_states


class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """
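    # With the values the blocks above pass in (dim=1152, mult=4 -> inner_dim=4608) and
    # activation_fn="gelu-approximate", this amounts to Linear(1152 -> 4608) with tanh-approximated GELU,
    # dropout, then Linear(4608 -> 1152).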

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim, bias=bias)
        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
        elif activation_fn == "swiglu":
            act_fn = SwiGLU(dim, inner_dim, bias=bias)
        elif activation_fn == "linear-silu":
            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states