AlexGraikos committed
Commit ea72b80 · verified · 1 Parent(s): 0a9f8f3

Delete pixcell_transformer_2d.py

Files changed (1)
  1. pixcell_transformer_2d.py +0 -676
pixcell_transformer_2d.py DELETED
@@ -1,676 +0,0 @@
- # Copyright 2024 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any, Dict, Optional, Union
-
- import torch
- from torch import nn
-
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.utils import is_torch_version, logging
- from diffusers.models.attention import BasicTransformerBlock
- from diffusers.models.attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
- from diffusers.models.embeddings import PatchEmbed
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
- from diffusers.models.modeling_utils import ModelMixin
- from diffusers.models.normalization import AdaLayerNormSingle
- from diffusers.models.activations import deprecate, FP32SiLU
-
-
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
-
- # PixCell UNI conditioning
- def pixcell_get_2d_sincos_pos_embed(
-     embed_dim,
-     grid_size,
-     cls_token=False,
-     extra_tokens=0,
-     interpolation_scale=1.0,
-     base_size=16,
-     device: Optional[torch.device] = None,
-     phase=0,
-     output_type: str = "np",
- ):
-     """
-     Creates 2D sinusoidal positional embeddings.
-
-     Args:
-         embed_dim (`int`):
-             The embedding dimension.
-         grid_size (`int`):
-             The size of the grid height and width.
-         cls_token (`bool`, defaults to `False`):
-             Whether or not to add a classification token.
-         extra_tokens (`int`, defaults to `0`):
-             The number of extra tokens to add.
-         interpolation_scale (`float`, defaults to `1.0`):
-             The scale of the interpolation.
-
-     Returns:
-         pos_embed (`torch.Tensor`):
-             Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
-             embed_dim]` if using cls_token
-     """
-     if output_type == "np":
-         deprecation_message = (
-             "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
-             " `from_numpy` is no longer required."
-             " Pass `output_type='pt' to use the new version now."
-         )
-         deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
-         raise ValueError("Not supported")
-     if isinstance(grid_size, int):
-         grid_size = (grid_size, grid_size)
-
-     grid_h = (
-         torch.arange(grid_size[0], device=device, dtype=torch.float32)
-         / (grid_size[0] / base_size)
-         / interpolation_scale
-     )
-     grid_w = (
-         torch.arange(grid_size[1], device=device, dtype=torch.float32)
-         / (grid_size[1] / base_size)
-         / interpolation_scale
-     )
-     grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
-     grid = torch.stack(grid, dim=0)
-
-     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-     pos_embed = pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=phase, output_type=output_type)
-     if cls_token and extra_tokens > 0:
-         pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
-     return pos_embed
-
-
- def pixcell_get_2d_sincos_pos_embed_from_grid(embed_dim, grid, phase=0, output_type="np"):
-     r"""
-     This function generates 2D sinusoidal positional embeddings from a grid.
-
-     Args:
-         embed_dim (`int`): The embedding dimension.
-         grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
-
-     Returns:
-         `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
-     """
-     if output_type == "np":
-         deprecation_message = (
-             "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
-             " `from_numpy` is no longer required."
-             " Pass `output_type='pt' to use the new version now."
-         )
-         deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
-         raise ValueError("Not supported")
-     if embed_dim % 2 != 0:
-         raise ValueError("embed_dim must be divisible by 2")
-
-     # use half of dimensions to encode grid_h
-     emb_h = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], phase=phase, output_type=output_type) # (H*W, D/2)
-     emb_w = pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], phase=phase, output_type=output_type) # (H*W, D/2)
-
-     emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
-     return emb
-
-
- def pixcell_get_1d_sincos_pos_embed_from_grid(embed_dim, pos, phase=0, output_type="np"):
-     """
-     This function generates 1D positional embeddings from a grid.
-
-     Args:
-         embed_dim (`int`): The embedding dimension `D`
-         pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
-
-     Returns:
-         `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
-     """
-     if output_type == "np":
-         deprecation_message = (
-             "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
-             " `from_numpy` is no longer required."
-             " Pass `output_type='pt' to use the new version now."
-         )
-         deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
-         raise ValueError("Not supported")
-     if embed_dim % 2 != 0:
-         raise ValueError("embed_dim must be divisible by 2")
-
-     omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
-     omega /= embed_dim / 2.0
-     omega = 1.0 / 10000**omega # (D/2,)
-
-     pos = pos.reshape(-1) + phase # (M,)
-     out = torch.outer(pos, omega) # (M, D/2), outer product
-
-     emb_sin = torch.sin(out) # (M, D/2)
-     emb_cos = torch.cos(out) # (M, D/2)
-
-     emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
-     return emb
-
-
- class PixcellUNIProjection(nn.Module):
-     """
-     Projects UNI embeddings. Also handles dropout for classifier-free guidance.
-
-     Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
-     """
-
-     def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", num_tokens=1):
-         super().__init__()
-         if out_features is None:
-             out_features = hidden_size
-         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
-         if act_fn == "gelu_tanh":
-             self.act_1 = nn.GELU(approximate="tanh")
-         elif act_fn == "silu":
-             self.act_1 = nn.SiLU()
-         elif act_fn == "silu_fp32":
-             self.act_1 = FP32SiLU()
-         else:
-             raise ValueError(f"Unknown activation function: {act_fn}")
-         self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
-
-         self.register_buffer("uncond_embedding", nn.Parameter(torch.randn(num_tokens, in_features) / in_features ** 0.5))
-
-     def forward(self, caption):
-         hidden_states = self.linear_1(caption)
-         hidden_states = self.act_1(hidden_states)
-         hidden_states = self.linear_2(hidden_states)
-         return hidden_states
-
- class UNIPosEmbed(nn.Module):
-     """
-     Adds positional embeddings to the UNI conditions.
-
-     Args:
-         height (`int`, defaults to `224`): The height of the image.
-         width (`int`, defaults to `224`): The width of the image.
-         patch_size (`int`, defaults to `16`): The size of the patches.
-         in_channels (`int`, defaults to `3`): The number of input channels.
-         embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
-         layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
-         flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
-         bias (`bool`, defaults to `True`): Whether or not to use bias.
-         interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
-         pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
-         pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
-     """
-
-     def __init__(
-         self,
-         height=1,
-         width=1,
-         base_size=16,
-         embed_dim=768,
-         interpolation_scale=1,
-         pos_embed_type="sincos",
-     ):
-         super().__init__()
-
-         num_embeds = height*width
-         grid_size = int(num_embeds ** 0.5)
-
-         if pos_embed_type == "sincos":
-             y_pos_embed = pixcell_get_2d_sincos_pos_embed(
-                 embed_dim,
-                 grid_size,
-                 base_size=base_size,
-                 interpolation_scale=interpolation_scale,
-                 output_type="pt",
-                 phase = base_size // num_embeds
-             )
-             self.register_buffer("y_pos_embed", y_pos_embed.float().unsqueeze(0))
-         else:
-             raise ValueError("`pos_embed_type` not supported")
-
-     def forward(self, uni_embeds):
-         return (uni_embeds + self.y_pos_embed).to(uni_embeds.dtype)
-
-
-
- class PixCellTransformer2DModel(ModelMixin, ConfigMixin):
-     r"""
-     A 2D Transformer model as introduced in PixArt family of models (https://arxiv.org/abs/2310.00426,
-     https://arxiv.org/abs/2403.04692). Modified for the pathology domain.
-
-     Parameters:
-         num_attention_heads (int, optional, defaults to 16): The number of heads to use for multi-head attention.
-         attention_head_dim (int, optional, defaults to 72): The number of channels in each head.
-         in_channels (int, defaults to 4): The number of channels in the input.
-         out_channels (int, optional):
-             The number of channels in the output. Specify this parameter if the output channel number differs from the
-             input.
-         num_layers (int, optional, defaults to 28): The number of layers of Transformer blocks to use.
-         dropout (float, optional, defaults to 0.0): The dropout probability to use within the Transformer blocks.
-         norm_num_groups (int, optional, defaults to 32):
-             Number of groups for group normalization within Transformer blocks.
-         cross_attention_dim (int, optional):
-             The dimensionality for cross-attention layers, typically matching the encoder's hidden dimension.
-         attention_bias (bool, optional, defaults to True):
-             Configure if the Transformer blocks' attention should contain a bias parameter.
-         sample_size (int, defaults to 128):
-             The width of the latent images. This parameter is fixed during training.
-         patch_size (int, defaults to 2):
-             Size of the patches the model processes, relevant for architectures working on non-sequential data.
-         activation_fn (str, optional, defaults to "gelu-approximate"):
-             Activation function to use in feed-forward networks within Transformer blocks.
-         num_embeds_ada_norm (int, optional, defaults to 1000):
-             Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
-             inference.
-         upcast_attention (bool, optional, defaults to False):
-             If true, upcasts the attention mechanism dimensions for potentially improved performance.
-         norm_type (str, optional, defaults to "ada_norm_zero"):
-             Specifies the type of normalization used, can be 'ada_norm_zero'.
-         norm_elementwise_affine (bool, optional, defaults to False):
-             If true, enables element-wise affine parameters in the normalization layers.
-         norm_eps (float, optional, defaults to 1e-6):
-             A small constant added to the denominator in normalization layers to prevent division by zero.
-         interpolation_scale (int, optional): Scale factor to use during interpolating the position embeddings.
-         use_additional_conditions (bool, optional): If we're using additional conditions as inputs.
-         attention_type (str, optional, defaults to "default"): Kind of attention mechanism to be used.
-         caption_channels (int, optional, defaults to None):
-             Number of channels to use for projecting the caption embeddings.
-         use_linear_projection (bool, optional, defaults to False):
-             Deprecated argument. Will be removed in a future version.
-         num_vector_embeds (bool, optional, defaults to False):
-             Deprecated argument. Will be removed in a future version.
-     """
-
-     _supports_gradient_checkpointing = True
-     _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
-
-     @register_to_config
-     def __init__(
-         self,
-         num_attention_heads: int = 16,
-         attention_head_dim: int = 72,
-         in_channels: int = 4,
-         out_channels: Optional[int] = 8,
-         num_layers: int = 28,
-         dropout: float = 0.0,
-         norm_num_groups: int = 32,
-         cross_attention_dim: Optional[int] = 1152,
-         attention_bias: bool = True,
-         sample_size: int = 128,
-         patch_size: int = 2,
-         activation_fn: str = "gelu-approximate",
-         num_embeds_ada_norm: Optional[int] = 1000,
-         upcast_attention: bool = False,
-         norm_type: str = "ada_norm_single",
-         norm_elementwise_affine: bool = False,
-         norm_eps: float = 1e-6,
-         interpolation_scale: Optional[int] = None,
-         use_additional_conditions: Optional[bool] = None,
-         caption_channels: Optional[int] = None,
-         caption_num_tokens: int = 1,
-         attention_type: Optional[str] = "default",
-     ):
-         super().__init__()
-
-         # Validate inputs.
-         if norm_type != "ada_norm_single":
-             raise NotImplementedError(
-                 f"Forward pass is not implemented when `patch_size` is not None and `norm_type` is '{norm_type}'."
-             )
-         elif norm_type == "ada_norm_single" and num_embeds_ada_norm is None:
-             raise ValueError(
-                 f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
-             )
-
-         # Set some common variables used across the board.
-         self.attention_head_dim = attention_head_dim
-         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
-         self.out_channels = in_channels if out_channels is None else out_channels
-         if use_additional_conditions is None:
-             if sample_size == 128:
-                 use_additional_conditions = True
-             else:
-                 use_additional_conditions = False
-         self.use_additional_conditions = use_additional_conditions
-
-         self.gradient_checkpointing = False
-
-         # 2. Initialize the position embedding and transformer blocks.
-         self.height = self.config.sample_size
-         self.width = self.config.sample_size
-
-         interpolation_scale = (
-             self.config.interpolation_scale
-             if self.config.interpolation_scale is not None
-             else max(self.config.sample_size // 64, 1)
-         )
-         self.pos_embed = PatchEmbed(
-             height=self.config.sample_size,
-             width=self.config.sample_size,
-             patch_size=self.config.patch_size,
-             in_channels=self.config.in_channels,
-             embed_dim=self.inner_dim,
-             interpolation_scale=interpolation_scale,
-         )
-
-         self.transformer_blocks = nn.ModuleList(
-             [
-                 BasicTransformerBlock(
-                     self.inner_dim,
-                     self.config.num_attention_heads,
-                     self.config.attention_head_dim,
-                     dropout=self.config.dropout,
-                     cross_attention_dim=self.config.cross_attention_dim,
-                     activation_fn=self.config.activation_fn,
-                     num_embeds_ada_norm=self.config.num_embeds_ada_norm,
-                     attention_bias=self.config.attention_bias,
-                     upcast_attention=self.config.upcast_attention,
-                     norm_type=norm_type,
-                     norm_elementwise_affine=self.config.norm_elementwise_affine,
-                     norm_eps=self.config.norm_eps,
-                     attention_type=self.config.attention_type,
-                 )
-                 for _ in range(self.config.num_layers)
-             ]
-         )
-
-         # Initialize the positional embedding for the conditions for >1 UNI embeddings
-         if self.config.caption_num_tokens == 1:
-             self.y_pos_embed = None
-         else:
-             # 1:1 aspect ratio
-             self.uni_height = int(self.config.caption_num_tokens ** 0.5)
-             self.uni_width = int(self.config.caption_num_tokens ** 0.5)
-
-             self.y_pos_embed = UNIPosEmbed(
-                 height=self.uni_height,
-                 width=self.uni_width,
-                 base_size=self.config.sample_size // self.config.patch_size,
-                 embed_dim=self.config.caption_channels,
-                 interpolation_scale=2, # Should this be fixed?
-                 pos_embed_type="sincos", # This is fixed
-             )
-
-         # 3. Output blocks.
-         self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
-         self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
-         self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
-
-         self.adaln_single = AdaLayerNormSingle(
-             self.inner_dim, use_additional_conditions=self.use_additional_conditions
-         )
-         self.caption_projection = None
-         if self.config.caption_channels is not None:
-             self.caption_projection = PixcellUNIProjection(
-                 in_features=self.config.caption_channels, hidden_size=self.inner_dim, num_tokens=self.config.caption_num_tokens,
-             )
-
-     def _set_gradient_checkpointing(self, module, value=False):
-         if hasattr(module, "gradient_checkpointing"):
-             module.gradient_checkpointing = value
-
-     @property
-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-     def attn_processors(self) -> Dict[str, AttentionProcessor]:
-         r"""
-         Returns:
-             `dict` of attention processors: A dictionary containing all attention processors used in the model with
-             indexed by its weight name.
-         """
-         # set recursively
-         processors = {}
-
-         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-             if hasattr(module, "get_processor"):
-                 processors[f"{name}.processor"] = module.get_processor()
-
-             for sub_name, child in module.named_children():
-                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-             return processors
-
-         for name, module in self.named_children():
-             fn_recursive_add_processors(name, module, processors)
-
-         return processors
-
-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-         r"""
-         Sets the attention processor to use to compute attention.
-
-         Parameters:
-             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                 for **all** `Attention` layers.
-
-                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                 processor. This is strongly recommended when setting trainable attention processors.
-
-         """
-         count = len(self.attn_processors.keys())
-
-         if isinstance(processor, dict) and len(processor) != count:
-             raise ValueError(
-                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-             )
-
-         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-             if hasattr(module, "set_processor"):
-                 if not isinstance(processor, dict):
-                     module.set_processor(processor)
-                 else:
-                     module.set_processor(processor.pop(f"{name}.processor"))
-
-             for sub_name, child in module.named_children():
-                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-         for name, module in self.named_children():
-             fn_recursive_attn_processor(name, module, processor)
-
-     def set_default_attn_processor(self):
-         """
-         Disables custom attention processors and sets the default attention implementation.
-
-         Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
-         """
-         self.set_attn_processor(AttnProcessor())
-
-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
-     def fuse_qkv_projections(self):
-         """
-         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-         are fused. For cross-attention modules, key and value projection matrices are fused.
-
-         <Tip warning={true}>
-
-         This API is 🧪 experimental.
-
-         </Tip>
-         """
-         self.original_attn_processors = None
-
-         for _, attn_processor in self.attn_processors.items():
-             if "Added" in str(attn_processor.__class__.__name__):
-                 raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-         self.original_attn_processors = self.attn_processors
-
-         for module in self.modules():
-             if isinstance(module, Attention):
-                 module.fuse_projections(fuse=True)
-
-         self.set_attn_processor(FusedAttnProcessor2_0())
-
-     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-     def unfuse_qkv_projections(self):
-         """Disables the fused QKV projection if enabled.
-
-         <Tip warning={true}>
-
-         This API is 🧪 experimental.
-
-         </Tip>
-
-         """
-         if self.original_attn_processors is not None:
-             self.set_attn_processor(self.original_attn_processors)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         timestep: Optional[torch.LongTensor] = None,
-         added_cond_kwargs: Dict[str, torch.Tensor] = None,
-         cross_attention_kwargs: Dict[str, Any] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         encoder_attention_mask: Optional[torch.Tensor] = None,
-         return_dict: bool = True,
-     ):
-         """
-         The [`PixCellTransformer2DModel`] forward method.
-
-         Args:
-             hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
-                 Input `hidden_states`.
-             encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
-                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
-                 self-attention.
-             timestep (`torch.LongTensor`, *optional*):
-                 Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
-             added_cond_kwargs: (`Dict[str, Any]`, *optional*): Additional conditions to be used as inputs.
-             cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
-                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                 `self.processor` in
-                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-             attention_mask ( `torch.Tensor`, *optional*):
-                 An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
-                 is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
-                 negative values to the attention scores corresponding to "discard" tokens.
-             encoder_attention_mask ( `torch.Tensor`, *optional*):
-                 Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
-
-                 * Mask `(batch, sequence_length)` True = keep, False = discard.
-                 * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
-
-                 If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
-                 above. This bias will be added to the cross-attention scores.
-             return_dict (`bool`, *optional*, defaults to `True`):
-                 Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
-                 tuple.
-
-         Returns:
-             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-             `tuple` where the first element is the sample tensor.
-         """
-         if self.use_additional_conditions and added_cond_kwargs is None:
-             raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
-
-         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
-         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
-         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
-         # expects mask of shape:
-         # [batch, key_tokens]
-         # adds singleton query_tokens dimension:
-         # [batch, 1, key_tokens]
-         # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
-         # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
-         # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
-         if attention_mask is not None and attention_mask.ndim == 2:
-             # assume that mask is expressed as:
-             # (1 = keep, 0 = discard)
-             # convert mask into a bias that can be added to attention scores:
-             # (keep = +0, discard = -10000.0)
-             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
-             attention_mask = attention_mask.unsqueeze(1)
-
-         # convert encoder_attention_mask to a bias the same way we do for attention_mask
-         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
-             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
-             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-
-         # 1. Input
-         batch_size = hidden_states.shape[0]
-         height, width = (
-             hidden_states.shape[-2] // self.config.patch_size,
-             hidden_states.shape[-1] // self.config.patch_size,
-         )
-         hidden_states = self.pos_embed(hidden_states)
-
-         timestep, embedded_timestep = self.adaln_single(
-             timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-         )
-
-         if self.caption_projection is not None:
-             # Add positional embeddings to conditions if >1 UNI are given
-             if self.y_pos_embed is not None:
-                 encoder_hidden_states = self.y_pos_embed(encoder_hidden_states)
-             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
-             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
-         # 2. Blocks
-         for block in self.transformer_blocks:
-             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                 def create_custom_forward(module, return_dict=None):
-                     def custom_forward(*inputs):
-                         if return_dict is not None:
-                             return module(*inputs, return_dict=return_dict)
-                         else:
-                             return module(*inputs)
-
-                     return custom_forward
-
-                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                 hidden_states = torch.utils.checkpoint.checkpoint(
-                     create_custom_forward(block),
-                     hidden_states,
-                     attention_mask,
-                     encoder_hidden_states,
-                     encoder_attention_mask,
-                     timestep,
-                     cross_attention_kwargs,
-                     None,
-                     **ckpt_kwargs,
-                 )
-             else:
-                 hidden_states = block(
-                     hidden_states,
-                     attention_mask=attention_mask,
-                     encoder_hidden_states=encoder_hidden_states,
-                     encoder_attention_mask=encoder_attention_mask,
-                     timestep=timestep,
-                     cross_attention_kwargs=cross_attention_kwargs,
-                     class_labels=None,
-                 )
-
-         # 3. Output
-         shift, scale = (
-             self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-         ).chunk(2, dim=1)
-         hidden_states = self.norm_out(hidden_states)
-         # Modulation
-         hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
-         hidden_states = self.proj_out(hidden_states)
-         hidden_states = hidden_states.squeeze(1)
-
-         # unpatchify
-         hidden_states = hidden_states.reshape(
-             shape=(-1, height, width, self.config.patch_size, self.config.patch_size, self.out_channels)
-         )
-         hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
-         output = hidden_states.reshape(
-             shape=(-1, self.out_channels, height * self.config.patch_size, width * self.config.patch_size)
-         )
-
-         if not return_dict:
-             return (output,)
-
-         return Transformer2DModelOutput(sample=output)
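
For reference, a minimal usage sketch of the model class this file defined (illustrative only). It assumes a local copy of the pre-deletion pixcell_transformer_2d.py on the Python path, a recent diffusers release, and a 1024-dimensional UNI-style conditioning embedding; the sizes below are small placeholders rather than the released checkpoint configuration.

import torch

from pixcell_transformer_2d import PixCellTransformer2DModel  # hypothetical local copy of the deleted file

# Tiny configuration for illustration; released checkpoints may use different values.
model = PixCellTransformer2DModel(
    num_attention_heads=16,
    attention_head_dim=72,
    in_channels=4,
    out_channels=4,
    num_layers=2,            # default is 28; kept small here
    sample_size=32,          # latent grid size; 128 would enable the additional-condition path
    patch_size=2,
    cross_attention_dim=1152,
    caption_channels=1024,   # assumed UNI embedding width
    caption_num_tokens=1,
)

latents = torch.randn(1, 4, 32, 32)      # noisy VAE latents
uni_embedding = torch.randn(1, 1, 1024)  # one UNI conditioning token per sample
timestep = torch.tensor([999])

with torch.no_grad():
    out = model(
        latents,
        encoder_hidden_states=uni_embedding,
        timestep=timestep,
        added_cond_kwargs={"resolution": None, "aspect_ratio": None},
    ).sample

print(out.shape)  # torch.Size([1, 4, 32, 32])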