twodgirl committed (verified)
Commit 034e8a3 · 1 Parent(s): 1800214

Upload folder using huggingface_hub

transformer/.ipynb_checkpoints/config-checkpoint.json ADDED
@@ -0,0 +1,27 @@
{
  "_class_name": "OmniGen2Transformer2DModel",
  "_diffusers_version": "0.33.1",
  "axes_dim_rope": [
    40,
    40,
    40
  ],
  "axes_lens": [
    1024,
    1664,
    1664
  ],
  "ffn_dim_multiplier": null,
  "hidden_size": 2520,
  "in_channels": 16,
  "multiple_of": 256,
  "norm_eps": 1e-05,
  "num_attention_heads": 21,
  "num_kv_heads": 7,
  "num_layers": 32,
  "num_refiner_layers": 2,
  "out_channels": null,
  "patch_size": 2,
  "text_feat_dim": 2048,
  "timestep_scale": 1000.0
}
transformer/config.json ADDED
@@ -0,0 +1,27 @@
{
  "_class_name": "OmniGen2Transformer2DModel",
  "_diffusers_version": "0.33.1",
  "axes_dim_rope": [
    40,
    40,
    40
  ],
  "axes_lens": [
    1024,
    1664,
    1664
  ],
  "ffn_dim_multiplier": null,
  "hidden_size": 2520,
  "in_channels": 16,
  "multiple_of": 256,
  "norm_eps": 1e-05,
  "num_attention_heads": 21,
  "num_kv_heads": 7,
  "num_layers": 32,
  "num_refiner_layers": 2,
  "out_channels": null,
  "patch_size": 2,
  "text_feat_dim": 2048,
  "timestep_scale": 1000.0
}
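For reference, the config is internally consistent: with hidden_size 2520 and 21 attention heads the per-head dimension is 120, which matches the sum of axes_dim_rope (40 + 40 + 40). Below is a minimal, hedged sketch of instantiating the transformer from this folder; the repo id is a placeholder, and it assumes transformer_omnigen2.py from this commit is importable locally (it is not part of this upload's instructions).

```python
# Minimal sketch (not from this commit): load the transformer defined by the config above.
# "<repo-id>" is a placeholder; replace it with the actual Hugging Face repository id.
import sys

import torch
from huggingface_hub import snapshot_download

local_dir = snapshot_download("<repo-id>", allow_patterns=["transformer/*"])

# "_class_name" refers to the class defined in transformer/transformer_omnigen2.py.
sys.path.append(f"{local_dir}/transformer")
from transformer_omnigen2 import OmniGen2Transformer2DModel

# from_pretrained reads config.json and diffusion_pytorch_model.safetensors from the subfolder.
model = OmniGen2Transformer2DModel.from_pretrained(
    local_dir, subfolder="transformer", torch_dtype=torch.bfloat16
)
print(model.config.hidden_size)  # 2520, as in the config above
```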
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06d0dc5cd21197b5fedfd3e3a2b0a4dd49048bc165944076b3107adb00a42bbf
size 7798909776
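The pointer stores the SHA-256 of the actual weights file (oid) and its byte count (size). A minimal sketch for verifying a downloaded copy against this pointer; the local path is a placeholder.

```python
# Verify a downloaded diffusion_pytorch_model.safetensors against the LFS pointer above.
import hashlib
import os

path = "transformer/diffusion_pytorch_model.safetensors"  # placeholder local path

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha.update(chunk)

assert os.path.getsize(path) == 7798909776
assert sha.hexdigest() == "06d0dc5cd21197b5fedfd3e3a2b0a4dd49048bc165944076b3107adb00a42bbf"
```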
transformer/transformer_omnigen2.py ADDED
@@ -0,0 +1,2104 @@
1
+ import warnings
2
+ import itertools
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange, repeat
11
+
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.loaders import PeftAdapterMixin
14
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
15
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
16
+ from diffusers.models.attention_processor import Attention
17
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+ from diffusers.models.embeddings import get_1d_rotary_pos_embed
20
+ from diffusers.models.activations import get_activation
21
+ from diffusers.models.embeddings import Timesteps
22
+
23
+ import importlib.util
24
+ import sys
25
+
26
+ # The package importlib_metadata is in a different place, depending on the python version.
27
+ if sys.version_info < (3, 8):
28
+ import importlib_metadata
29
+ else:
30
+ import importlib.metadata as importlib_metadata
31
+
32
+ def _is_package_available(pkg_name: str):
33
+ pkg_exists = importlib.util.find_spec(pkg_name) is not None
34
+ pkg_version = "N/A"
35
+
36
+ if pkg_exists:
37
+ try:
38
+ pkg_version = importlib_metadata.version(pkg_name)
39
+ except (ImportError, importlib_metadata.PackageNotFoundError):
40
+ pkg_exists = False
41
+
42
+ return pkg_exists, pkg_version
43
+
44
+ _triton_available, _triton_version = _is_package_available("triton")
45
+ _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
46
+
47
+ def is_triton_available():
48
+ return _triton_available
49
+
50
+ def is_flash_attn_available():
51
+ return _flash_attn_available
52
+
53
+ if is_triton_available():
54
+ # from ...ops.triton.layer_norm import RMSNorm
55
+ import triton
56
+ import triton.language as tl
57
+
58
+
59
+ from typing import Callable
60
+
61
+
62
+ def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
63
+ def decorator(*args, **kwargs):
64
+ if cuda_amp_deprecated:
65
+ kwargs["device_type"] = "cuda"
66
+ return dec(*args, **kwargs)
67
+ return decorator
68
+
69
+
70
+ if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
71
+ deprecated = True
72
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
73
+ else:
74
+ deprecated = False
75
+ from torch.cuda.amp import custom_fwd, custom_bwd
76
+
77
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
78
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
79
+
80
+
81
+ def triton_autotune_configs():
82
+ # Return configs with a valid warp count for the current device
83
+ configs=[]
84
+ # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
85
+ max_threads_per_block=1024
86
+ # Default to warp size 32 if not defined by device
87
+ warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
88
+ # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
89
+ warp_count=1
90
+ while warp_count*warp_size <= max_threads_per_block:
91
+ configs.append(triton.Config({}, num_warps=warp_count))
92
+ warp_count*=2
93
+ return configs
94
+
95
+ @triton.autotune(
96
+ configs=triton_autotune_configs(),
97
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
98
+ )
99
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
100
+ # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
101
+ @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
102
+ @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
103
+ @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
104
+ @triton.jit
105
+ def _layer_norm_fwd_1pass_kernel(
106
+ X, # pointer to the input
107
+ Y, # pointer to the output
108
+ W, # pointer to the weights
109
+ B, # pointer to the biases
110
+ RESIDUAL, # pointer to the residual
111
+ X1,
112
+ W1,
113
+ B1,
114
+ Y1,
115
+ RESIDUAL_OUT, # pointer to the residual
116
+ ROWSCALE,
117
+ SEEDS, # Dropout seeds for each row
118
+ DROPOUT_MASK,
119
+ Mean, # pointer to the mean
120
+ Rstd, # pointer to the 1/std
121
+ stride_x_row, # how much to increase the pointer when moving by 1 row
122
+ stride_y_row,
123
+ stride_res_row,
124
+ stride_res_out_row,
125
+ stride_x1_row,
126
+ stride_y1_row,
127
+ M, # number of rows in X
128
+ N, # number of columns in X
129
+ eps, # epsilon to avoid division by zero
130
+ dropout_p, # Dropout probability
131
+ zero_centered_weight, # If true, add 1.0 to the weight
132
+ IS_RMS_NORM: tl.constexpr,
133
+ BLOCK_N: tl.constexpr,
134
+ HAS_RESIDUAL: tl.constexpr,
135
+ STORE_RESIDUAL_OUT: tl.constexpr,
136
+ HAS_BIAS: tl.constexpr,
137
+ HAS_DROPOUT: tl.constexpr,
138
+ STORE_DROPOUT_MASK: tl.constexpr,
139
+ HAS_ROWSCALE: tl.constexpr,
140
+ HAS_X1: tl.constexpr,
141
+ HAS_W1: tl.constexpr,
142
+ HAS_B1: tl.constexpr,
143
+ ):
144
+ # Map the program id to the row of X and Y it should compute.
145
+ row = tl.program_id(0)
146
+ X += row * stride_x_row
147
+ Y += row * stride_y_row
148
+ if HAS_RESIDUAL:
149
+ RESIDUAL += row * stride_res_row
150
+ if STORE_RESIDUAL_OUT:
151
+ RESIDUAL_OUT += row * stride_res_out_row
152
+ if HAS_X1:
153
+ X1 += row * stride_x1_row
154
+ if HAS_W1:
155
+ Y1 += row * stride_y1_row
156
+ # Compute mean and variance
157
+ cols = tl.arange(0, BLOCK_N)
158
+ x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
159
+ if HAS_ROWSCALE:
160
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
161
+ x *= rowscale
162
+ if HAS_DROPOUT:
163
+ # Compute dropout mask
164
+ # 7 rounds is good enough, and reduces register pressure
165
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
166
+ x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
167
+ if STORE_DROPOUT_MASK:
168
+ tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
169
+ if HAS_X1:
170
+ x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
171
+ if HAS_ROWSCALE:
172
+ rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
173
+ x1 *= rowscale
174
+ if HAS_DROPOUT:
175
+ # Compute dropout mask
176
+ # 7 rounds is good enough, and reduces register pressure
177
+ keep_mask = (
178
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
179
+ )
180
+ x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
181
+ if STORE_DROPOUT_MASK:
182
+ tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
183
+ x += x1
184
+ if HAS_RESIDUAL:
185
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
186
+ x += residual
187
+ if STORE_RESIDUAL_OUT:
188
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
189
+ if not IS_RMS_NORM:
190
+ mean = tl.sum(x, axis=0) / N
191
+ tl.store(Mean + row, mean)
192
+ xbar = tl.where(cols < N, x - mean, 0.0)
193
+ var = tl.sum(xbar * xbar, axis=0) / N
194
+ else:
195
+ xbar = tl.where(cols < N, x, 0.0)
196
+ var = tl.sum(xbar * xbar, axis=0) / N
197
+ rstd = 1 / tl.sqrt(var + eps)
198
+ tl.store(Rstd + row, rstd)
199
+ # Normalize and apply linear transformation
200
+ mask = cols < N
201
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
202
+ if zero_centered_weight:
203
+ w += 1.0
204
+ if HAS_BIAS:
205
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
206
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
207
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
208
+ # Write output
209
+ tl.store(Y + cols, y, mask=mask)
210
+ if HAS_W1:
211
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
212
+ if zero_centered_weight:
213
+ w1 += 1.0
214
+ if HAS_B1:
215
+ b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
216
+ y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
217
+ tl.store(Y1 + cols, y1, mask=mask)
218
+
219
+
220
+ def _layer_norm_fwd(
221
+ x,
222
+ weight,
223
+ bias,
224
+ eps,
225
+ residual=None,
226
+ x1=None,
227
+ weight1=None,
228
+ bias1=None,
229
+ dropout_p=0.0,
230
+ rowscale=None,
231
+ out_dtype=None,
232
+ residual_dtype=None,
233
+ zero_centered_weight=False,
234
+ is_rms_norm=False,
235
+ return_dropout_mask=False,
236
+ out=None,
237
+ residual_out=None
238
+ ):
239
+ if residual is not None:
240
+ residual_dtype = residual.dtype
241
+ M, N = x.shape
242
+ assert x.stride(-1) == 1
243
+ if residual is not None:
244
+ assert residual.stride(-1) == 1
245
+ assert residual.shape == (M, N)
246
+ assert weight.shape == (N,)
247
+ assert weight.stride(-1) == 1
248
+ if bias is not None:
249
+ assert bias.stride(-1) == 1
250
+ assert bias.shape == (N,)
251
+ if x1 is not None:
252
+ assert x1.shape == x.shape
253
+ assert rowscale is None
254
+ assert x1.stride(-1) == 1
255
+ if weight1 is not None:
256
+ assert weight1.shape == (N,)
257
+ assert weight1.stride(-1) == 1
258
+ if bias1 is not None:
259
+ assert bias1.shape == (N,)
260
+ assert bias1.stride(-1) == 1
261
+ if rowscale is not None:
262
+ assert rowscale.is_contiguous()
263
+ assert rowscale.shape == (M,)
264
+ # allocate output
265
+ if out is None:
266
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
267
+ else:
268
+ assert out.shape == x.shape
269
+ assert out.stride(-1) == 1
270
+ if weight1 is not None:
271
+ y1 = torch.empty_like(out)
272
+ assert y1.stride(-1) == 1
273
+ else:
274
+ y1 = None
275
+ if (
276
+ residual is not None
277
+ or (residual_dtype is not None and residual_dtype != x.dtype)
278
+ or dropout_p > 0.0
279
+ or rowscale is not None
280
+ or x1 is not None
281
+ ):
282
+ if residual_out is None:
283
+ residual_out = torch.empty(
284
+ M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
285
+ )
286
+ else:
287
+ assert residual_out.shape == x.shape
288
+ assert residual_out.stride(-1) == 1
289
+ else:
290
+ residual_out = None
291
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
292
+ rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
293
+ if dropout_p > 0.0:
294
+ seeds = torch.randint(
295
+ 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
296
+ )
297
+ else:
298
+ seeds = None
299
+ if return_dropout_mask and dropout_p > 0.0:
300
+ dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
301
+ else:
302
+ dropout_mask = None
303
+ # Less than 64KB per feature: enqueue fused kernel
304
+ MAX_FUSED_SIZE = 65536 // x.element_size()
305
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
306
+ if N > BLOCK_N:
307
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
308
+ with torch.cuda.device(x.device.index):
309
+ _layer_norm_fwd_1pass_kernel[(M,)](
310
+ x,
311
+ out,
312
+ weight,
313
+ bias,
314
+ residual,
315
+ x1,
316
+ weight1,
317
+ bias1,
318
+ y1,
319
+ residual_out,
320
+ rowscale,
321
+ seeds,
322
+ dropout_mask,
323
+ mean,
324
+ rstd,
325
+ x.stride(0),
326
+ out.stride(0),
327
+ residual.stride(0) if residual is not None else 0,
328
+ residual_out.stride(0) if residual_out is not None else 0,
329
+ x1.stride(0) if x1 is not None else 0,
330
+ y1.stride(0) if y1 is not None else 0,
331
+ M,
332
+ N,
333
+ eps,
334
+ dropout_p,
335
+ zero_centered_weight,
336
+ is_rms_norm,
337
+ BLOCK_N,
338
+ residual is not None,
339
+ residual_out is not None,
340
+ bias is not None,
341
+ dropout_p > 0.0,
342
+ dropout_mask is not None,
343
+ rowscale is not None,
344
+ )
345
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
346
+ if dropout_mask is not None and x1 is not None:
347
+ dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
348
+ else:
349
+ dropout_mask1 = None
350
+ return (
351
+ out,
352
+ y1,
353
+ mean,
354
+ rstd,
355
+ residual_out if residual_out is not None else x,
356
+ seeds,
357
+ dropout_mask,
358
+ dropout_mask1,
359
+ )
360
+
361
+ @triton.autotune(
362
+ configs=triton_autotune_configs(),
363
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
364
+ )
365
+ # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
366
+ # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
367
+ # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
368
+ @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
369
+ @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
370
+ @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
371
+ @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
372
+ @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
373
+ @triton.jit
374
+ def _layer_norm_bwd_kernel(
375
+ X, # pointer to the input
376
+ W, # pointer to the weights
377
+ B, # pointer to the biases
378
+ Y, # pointer to the output to be recomputed
379
+ DY, # pointer to the output gradient
380
+ DX, # pointer to the input gradient
381
+ DW, # pointer to the partial sum of weights gradient
382
+ DB, # pointer to the partial sum of biases gradient
383
+ DRESIDUAL,
384
+ W1,
385
+ DY1,
386
+ DX1,
387
+ DW1,
388
+ DB1,
389
+ DRESIDUAL_IN,
390
+ ROWSCALE,
391
+ SEEDS,
392
+ Mean, # pointer to the mean
393
+ Rstd, # pointer to the 1/std
394
+ stride_x_row, # how much to increase the pointer when moving by 1 row
395
+ stride_y_row,
396
+ stride_dy_row,
397
+ stride_dx_row,
398
+ stride_dres_row,
399
+ stride_dy1_row,
400
+ stride_dx1_row,
401
+ stride_dres_in_row,
402
+ M, # number of rows in X
403
+ N, # number of columns in X
404
+ eps, # epsilon to avoid division by zero
405
+ dropout_p,
406
+ zero_centered_weight,
407
+ rows_per_program,
408
+ IS_RMS_NORM: tl.constexpr,
409
+ BLOCK_N: tl.constexpr,
410
+ HAS_DRESIDUAL: tl.constexpr,
411
+ STORE_DRESIDUAL: tl.constexpr,
412
+ HAS_BIAS: tl.constexpr,
413
+ HAS_DROPOUT: tl.constexpr,
414
+ HAS_ROWSCALE: tl.constexpr,
415
+ HAS_DY1: tl.constexpr,
416
+ HAS_DX1: tl.constexpr,
417
+ HAS_B1: tl.constexpr,
418
+ RECOMPUTE_OUTPUT: tl.constexpr,
419
+ ):
420
+ # Map the program id to the elements of X, DX, and DY it should compute.
421
+ row_block_id = tl.program_id(0)
422
+ row_start = row_block_id * rows_per_program
423
+ # Do not early exit if row_start >= M, because we need to write DW and DB
424
+ cols = tl.arange(0, BLOCK_N)
425
+ mask = cols < N
426
+ X += row_start * stride_x_row
427
+ if HAS_DRESIDUAL:
428
+ DRESIDUAL += row_start * stride_dres_row
429
+ if STORE_DRESIDUAL:
430
+ DRESIDUAL_IN += row_start * stride_dres_in_row
431
+ DY += row_start * stride_dy_row
432
+ DX += row_start * stride_dx_row
433
+ if HAS_DY1:
434
+ DY1 += row_start * stride_dy1_row
435
+ if HAS_DX1:
436
+ DX1 += row_start * stride_dx1_row
437
+ if RECOMPUTE_OUTPUT:
438
+ Y += row_start * stride_y_row
439
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
440
+ if zero_centered_weight:
441
+ w += 1.0
442
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
443
+ b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
444
+ if HAS_DY1:
445
+ w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
446
+ if zero_centered_weight:
447
+ w1 += 1.0
448
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
449
+ if HAS_BIAS:
450
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
451
+ if HAS_DY1:
452
+ dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
453
+ if HAS_B1:
454
+ db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
455
+ row_end = min((row_block_id + 1) * rows_per_program, M)
456
+ for row in range(row_start, row_end):
457
+ # Load data to SRAM
458
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
459
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
460
+ if HAS_DY1:
461
+ dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
462
+ if not IS_RMS_NORM:
463
+ mean = tl.load(Mean + row)
464
+ rstd = tl.load(Rstd + row)
465
+ # Compute dx
466
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
467
+ xhat = tl.where(mask, xhat, 0.0)
468
+ if RECOMPUTE_OUTPUT:
469
+ y = xhat * w + b if HAS_BIAS else xhat * w
470
+ tl.store(Y + cols, y, mask=mask)
471
+ wdy = w * dy
472
+ dw += dy * xhat
473
+ if HAS_BIAS:
474
+ db += dy
475
+ if HAS_DY1:
476
+ wdy += w1 * dy1
477
+ dw1 += dy1 * xhat
478
+ if HAS_B1:
479
+ db1 += dy1
480
+ if not IS_RMS_NORM:
481
+ c1 = tl.sum(xhat * wdy, axis=0) / N
482
+ c2 = tl.sum(wdy, axis=0) / N
483
+ dx = (wdy - (xhat * c1 + c2)) * rstd
484
+ else:
485
+ c1 = tl.sum(xhat * wdy, axis=0) / N
486
+ dx = (wdy - xhat * c1) * rstd
487
+ if HAS_DRESIDUAL:
488
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
489
+ dx += dres
490
+ # Write dx
491
+ if STORE_DRESIDUAL:
492
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
493
+ if HAS_DX1:
494
+ if HAS_DROPOUT:
495
+ keep_mask = (
496
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
497
+ )
498
+ dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
499
+ else:
500
+ dx1 = dx
501
+ tl.store(DX1 + cols, dx1, mask=mask)
502
+ if HAS_DROPOUT:
503
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
504
+ dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
505
+ if HAS_ROWSCALE:
506
+ rowscale = tl.load(ROWSCALE + row).to(tl.float32)
507
+ dx *= rowscale
508
+ tl.store(DX + cols, dx, mask=mask)
509
+
510
+ X += stride_x_row
511
+ if HAS_DRESIDUAL:
512
+ DRESIDUAL += stride_dres_row
513
+ if STORE_DRESIDUAL:
514
+ DRESIDUAL_IN += stride_dres_in_row
515
+ if RECOMPUTE_OUTPUT:
516
+ Y += stride_y_row
517
+ DY += stride_dy_row
518
+ DX += stride_dx_row
519
+ if HAS_DY1:
520
+ DY1 += stride_dy1_row
521
+ if HAS_DX1:
522
+ DX1 += stride_dx1_row
523
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
524
+ if HAS_BIAS:
525
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
526
+ if HAS_DY1:
527
+ tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
528
+ if HAS_B1:
529
+ tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
530
+
531
+
532
+ def _layer_norm_bwd(
533
+ dy,
534
+ x,
535
+ weight,
536
+ bias,
537
+ eps,
538
+ mean,
539
+ rstd,
540
+ dresidual=None,
541
+ dy1=None,
542
+ weight1=None,
543
+ bias1=None,
544
+ seeds=None,
545
+ dropout_p=0.0,
546
+ rowscale=None,
547
+ has_residual=False,
548
+ has_x1=False,
549
+ zero_centered_weight=False,
550
+ is_rms_norm=False,
551
+ x_dtype=None,
552
+ recompute_output=False,
553
+ ):
554
+ M, N = x.shape
555
+ assert x.stride(-1) == 1
556
+ assert dy.stride(-1) == 1
557
+ assert dy.shape == (M, N)
558
+ if dresidual is not None:
559
+ assert dresidual.stride(-1) == 1
560
+ assert dresidual.shape == (M, N)
561
+ assert weight.shape == (N,)
562
+ assert weight.stride(-1) == 1
563
+ if bias is not None:
564
+ assert bias.stride(-1) == 1
565
+ assert bias.shape == (N,)
566
+ if dy1 is not None:
567
+ assert weight1 is not None
568
+ assert dy1.shape == dy.shape
569
+ assert dy1.stride(-1) == 1
570
+ if weight1 is not None:
571
+ assert weight1.shape == (N,)
572
+ assert weight1.stride(-1) == 1
573
+ if bias1 is not None:
574
+ assert bias1.shape == (N,)
575
+ assert bias1.stride(-1) == 1
576
+ if seeds is not None:
577
+ assert seeds.is_contiguous()
578
+ assert seeds.shape == (M if not has_x1 else M * 2,)
579
+ if rowscale is not None:
580
+ assert rowscale.is_contiguous()
581
+ assert rowscale.shape == (M,)
582
+ # allocate output
583
+ dx = (
584
+ torch.empty_like(x)
585
+ if x_dtype is None
586
+ else torch.empty(M, N, dtype=x_dtype, device=x.device)
587
+ )
588
+ dresidual_in = (
589
+ torch.empty_like(x)
590
+ if has_residual
591
+ and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
592
+ else None
593
+ )
594
+ dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
595
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
596
+ if recompute_output:
597
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
598
+
599
+ # Less than 64KB per feature: enqueue fused kernel
600
+ MAX_FUSED_SIZE = 65536 // x.element_size()
601
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
602
+ if N > BLOCK_N:
603
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
604
+ # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
605
+ # latency of the gmem reads/writes, but will increase the time of summing up dw / db.
606
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
607
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
608
+ _db = (
609
+ torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
610
+ if bias is not None
611
+ else None
612
+ )
613
+ _dw1 = torch.empty_like(_dw) if weight1 is not None else None
614
+ _db1 = torch.empty_like(_db) if bias1 is not None else None
615
+ rows_per_program = math.ceil(M / sm_count)
616
+ grid = (sm_count,)
617
+ with torch.cuda.device(x.device.index):
618
+ _layer_norm_bwd_kernel[grid](
619
+ x,
620
+ weight,
621
+ bias,
622
+ y,
623
+ dy,
624
+ dx,
625
+ _dw,
626
+ _db,
627
+ dresidual,
628
+ weight1,
629
+ dy1,
630
+ dx1,
631
+ _dw1,
632
+ _db1,
633
+ dresidual_in,
634
+ rowscale,
635
+ seeds,
636
+ mean,
637
+ rstd,
638
+ x.stride(0),
639
+ 0 if not recompute_output else y.stride(0),
640
+ dy.stride(0),
641
+ dx.stride(0),
642
+ dresidual.stride(0) if dresidual is not None else 0,
643
+ dy1.stride(0) if dy1 is not None else 0,
644
+ dx1.stride(0) if dx1 is not None else 0,
645
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
646
+ M,
647
+ N,
648
+ eps,
649
+ dropout_p,
650
+ zero_centered_weight,
651
+ rows_per_program,
652
+ is_rms_norm,
653
+ BLOCK_N,
654
+ dresidual is not None,
655
+ dresidual_in is not None,
656
+ bias is not None,
657
+ dropout_p > 0.0,
658
+ )
659
+ dw = _dw.sum(0).to(weight.dtype)
660
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
661
+ dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
662
+ db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
663
+ # Don't need to compute dresidual_in separately in this case
664
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
665
+ dresidual_in = dx
666
+ if has_x1 and dropout_p == 0.0:
667
+ dx1 = dx
668
+ return (
669
+ (dx, dw, db, dresidual_in, dx1, dw1, db1)
670
+ if not recompute_output
671
+ else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
672
+ )
673
+
674
+ class LayerNormFn(torch.autograd.Function):
675
+ @staticmethod
676
+ def forward(
677
+ ctx,
678
+ x,
679
+ weight,
680
+ bias,
681
+ residual=None,
682
+ x1=None,
683
+ weight1=None,
684
+ bias1=None,
685
+ eps=1e-6,
686
+ dropout_p=0.0,
687
+ rowscale=None,
688
+ prenorm=False,
689
+ residual_in_fp32=False,
690
+ zero_centered_weight=False,
691
+ is_rms_norm=False,
692
+ return_dropout_mask=False,
693
+ out=None,
694
+ residual_out=None
695
+ ):
696
+ x_shape_og = x.shape
697
+ # Check for zero sequence length
698
+ if x.numel() == 0:
699
+ ctx.zero_seq_length = True
700
+ # Only save minimal required tensors for backward
701
+ # ctx.save_for_backward(weight, bias, weight1, bias1)
702
+ ctx.x_shape_og = x_shape_og
703
+ ctx.weight_shape = weight.shape
704
+ ctx.weight_dtype = weight.dtype
705
+ ctx.weight_device = weight.device
706
+
707
+ ctx.has_bias = bias is not None
708
+ ctx.bias_shape = bias.shape if bias is not None else None
709
+ ctx.bias_dtype = bias.dtype if bias is not None else None
710
+ ctx.bias_device = bias.device if bias is not None else None
711
+
712
+ ctx.has_weight1 = weight1 is not None
713
+ ctx.weight1_shape = weight1.shape if weight1 is not None else None
714
+ ctx.weight1_dtype = weight1.dtype if weight1 is not None else None
715
+ ctx.weight1_device = weight1.device if weight1 is not None else None
716
+
717
+ ctx.has_bias1 = bias1 is not None
718
+ ctx.bias1_shape = bias1.shape if bias1 is not None else None
719
+ ctx.bias1_dtype = bias1.dtype if bias1 is not None else None
720
+ ctx.bias1_device = bias1.device if bias1 is not None else None
721
+
722
+ ctx.has_residual = residual is not None
723
+ ctx.has_x1 = x1 is not None
724
+ ctx.dropout_p = dropout_p
725
+
726
+ # Handle output tensors with correct dtype
727
+ y = x # Preserve input tensor properties
728
+ y1 = torch.empty_like(x) if x1 is not None else None
729
+
730
+ # Only create residual_out if prenorm is True
731
+ residual_out = torch.empty(x.shape,
732
+ dtype=torch.float32 if residual_in_fp32 else x.dtype,
733
+ device=x.device) if prenorm else None
734
+
735
+ # Handle dropout masks
736
+ dropout_mask = None
737
+ dropout_mask1 = None
738
+ if return_dropout_mask:
739
+ dropout_mask = torch.empty_like(x, dtype=torch.uint8)
740
+ if x1 is not None:
741
+ dropout_mask1 = torch.empty_like(x, dtype=torch.uint8)
742
+
743
+ # Return based on configuration
744
+ if not return_dropout_mask:
745
+ if weight1 is None:
746
+ return y if not prenorm else (y, residual_out)
747
+ else:
748
+ return (y, y1) if not prenorm else (y, y1, residual_out)
749
+ else:
750
+ if weight1 is None:
751
+ return ((y, dropout_mask, dropout_mask1) if not prenorm
752
+ else (y, residual_out, dropout_mask, dropout_mask1))
753
+ else:
754
+ return ((y, y1, dropout_mask, dropout_mask1) if not prenorm
755
+ else (y, y1, residual_out, dropout_mask, dropout_mask1))
756
+
757
+ ctx.zero_seq_length = False
758
+ # reshape input data into 2D tensor
759
+ x = x.reshape(-1, x.shape[-1])
760
+ if x.stride(-1) != 1:
761
+ x = x.contiguous()
762
+ if residual is not None:
763
+ assert residual.shape == x_shape_og
764
+ residual = residual.reshape(-1, residual.shape[-1])
765
+ if residual.stride(-1) != 1:
766
+ residual = residual.contiguous()
767
+ if x1 is not None:
768
+ assert x1.shape == x_shape_og
769
+ assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
770
+ x1 = x1.reshape(-1, x1.shape[-1])
771
+ if x1.stride(-1) != 1:
772
+ x1 = x1.contiguous()
773
+ weight = weight.contiguous()
774
+ if bias is not None:
775
+ bias = bias.contiguous()
776
+ if weight1 is not None:
777
+ weight1 = weight1.contiguous()
778
+ if bias1 is not None:
779
+ bias1 = bias1.contiguous()
780
+ if rowscale is not None:
781
+ rowscale = rowscale.reshape(-1).contiguous()
782
+ residual_dtype = (
783
+ residual.dtype
784
+ if residual is not None
785
+ else (torch.float32 if residual_in_fp32 else None)
786
+ )
787
+ if out is not None:
788
+ out = out.reshape(-1, out.shape[-1])
789
+ if residual_out is not None:
790
+ residual_out = residual_out.reshape(-1, residual_out.shape[-1])
791
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
792
+ x,
793
+ weight,
794
+ bias,
795
+ eps,
796
+ residual,
797
+ x1,
798
+ weight1,
799
+ bias1,
800
+ dropout_p=dropout_p,
801
+ rowscale=rowscale,
802
+ residual_dtype=residual_dtype,
803
+ zero_centered_weight=zero_centered_weight,
804
+ is_rms_norm=is_rms_norm,
805
+ return_dropout_mask=return_dropout_mask,
806
+ out=out,
807
+ residual_out=residual_out
808
+ )
809
+ ctx.save_for_backward(
810
+ residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
811
+ )
812
+ ctx.x_shape_og = x_shape_og
813
+ ctx.eps = eps
814
+ ctx.dropout_p = dropout_p
815
+ ctx.is_rms_norm = is_rms_norm
816
+ ctx.has_residual = residual is not None
817
+ ctx.has_x1 = x1 is not None
818
+ ctx.prenorm = prenorm
819
+ ctx.x_dtype = x.dtype
820
+ ctx.zero_centered_weight = zero_centered_weight
821
+ y = y.reshape(x_shape_og)
822
+ y1 = y1.reshape(x_shape_og) if y1 is not None else None
823
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
824
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
825
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
826
+ if not return_dropout_mask:
827
+ if weight1 is None:
828
+ return y if not prenorm else (y, residual_out)
829
+ else:
830
+ return (y, y1) if not prenorm else (y, y1, residual_out)
831
+ else:
832
+ if weight1 is None:
833
+ return (
834
+ (y, dropout_mask, dropout_mask1)
835
+ if not prenorm
836
+ else (y, residual_out, dropout_mask, dropout_mask1)
837
+ )
838
+ else:
839
+ return (
840
+ (y, y1, dropout_mask, dropout_mask1)
841
+ if not prenorm
842
+ else (y, y1, residual_out, dropout_mask, dropout_mask1)
843
+ )
844
+
845
+ @staticmethod
846
+ def backward(ctx, dy, *args):
847
+ if ctx.zero_seq_length:
848
+ return (
849
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device),
850
+ torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device),
851
+ torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None,
852
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None,
853
+ torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None,
854
+ torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None,
855
+ torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None,
856
+ None,
857
+ None,
858
+ None,
859
+ None,
860
+ None,
861
+ None,
862
+ None,
863
+ None,
864
+ None,
865
+ None,
866
+ )
867
+
868
+ x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
869
+ dy = dy.reshape(-1, dy.shape[-1])
870
+ if dy.stride(-1) != 1:
871
+ dy = dy.contiguous()
872
+ assert dy.shape == x.shape
873
+ if weight1 is not None:
874
+ dy1, args = args[0], args[1:]
875
+ dy1 = dy1.reshape(-1, dy1.shape[-1])
876
+ if dy1.stride(-1) != 1:
877
+ dy1 = dy1.contiguous()
878
+ assert dy1.shape == x.shape
879
+ else:
880
+ dy1 = None
881
+ if ctx.prenorm:
882
+ dresidual = args[0]
883
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
884
+ if dresidual.stride(-1) != 1:
885
+ dresidual = dresidual.contiguous()
886
+ assert dresidual.shape == x.shape
887
+ else:
888
+ dresidual = None
889
+
890
+ dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
891
+ dy,
892
+ x,
893
+ weight,
894
+ bias,
895
+ ctx.eps,
896
+ mean,
897
+ rstd,
898
+ dresidual,
899
+ dy1,
900
+ weight1,
901
+ bias1,
902
+ seeds,
903
+ ctx.dropout_p,
904
+ rowscale,
905
+ ctx.has_residual,
906
+ ctx.has_x1,
907
+ ctx.zero_centered_weight,
908
+ ctx.is_rms_norm,
909
+ x_dtype=ctx.x_dtype,
910
+ )
911
+ return (
912
+ dx.reshape(ctx.x_shape_og),
913
+ dw,
914
+ db,
915
+ dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
916
+ dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
917
+ dw1,
918
+ db1,
919
+ None,
920
+ None,
921
+ None,
922
+ None,
923
+ None,
924
+ None,
925
+ None,
926
+ None,
927
+ None,
928
+ None,
929
+ )
930
+
931
+ def rms_norm_fn(
932
+ x,
933
+ weight,
934
+ bias,
935
+ residual=None,
936
+ x1=None,
937
+ weight1=None,
938
+ bias1=None,
939
+ eps=1e-6,
940
+ dropout_p=0.0,
941
+ rowscale=None,
942
+ prenorm=False,
943
+ residual_in_fp32=False,
944
+ zero_centered_weight=False,
945
+ return_dropout_mask=False,
946
+ out=None,
947
+ residual_out=None
948
+ ):
949
+ return LayerNormFn.apply(
950
+ x,
951
+ weight,
952
+ bias,
953
+ residual,
954
+ x1,
955
+ weight1,
956
+ bias1,
957
+ eps,
958
+ dropout_p,
959
+ rowscale,
960
+ prenorm,
961
+ residual_in_fp32,
962
+ zero_centered_weight,
963
+ True,
964
+ return_dropout_mask,
965
+ out,
966
+ residual_out
967
+ )
968
+
969
+ class RMSNorm(torch.nn.Module):
970
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
971
+ device=None, dtype=None):
972
+ factory_kwargs = {"device": device, "dtype": dtype}
973
+ super().__init__()
974
+ self.eps = eps
975
+ if dropout_p > 0.0:
976
+ self.drop = torch.nn.Dropout(dropout_p)
977
+ else:
978
+ self.drop = None
979
+ self.zero_centered_weight = zero_centered_weight
980
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
981
+ self.register_parameter("bias", None)
982
+ self.reset_parameters()
983
+
984
+ def reset_parameters(self):
985
+ if not self.zero_centered_weight:
986
+ torch.nn.init.ones_(self.weight)
987
+ else:
988
+ torch.nn.init.zeros_(self.weight)
989
+
990
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
991
+ return rms_norm_fn(
992
+ x,
993
+ self.weight,
994
+ self.bias,
995
+ residual=residual,
996
+ eps=self.eps,
997
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
998
+ prenorm=prenorm,
999
+ residual_in_fp32=residual_in_fp32,
1000
+ zero_centered_weight=self.zero_centered_weight,
1001
+ )
1002
+ else:
1003
+ from torch.nn import RMSNorm
1004
+ warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance")
1005
+
1006
+ def swiglu(x, y):
1007
+ return F.silu(x.float(), inplace=False).to(x.dtype) * y
1008
+
1009
+ logger = logging.get_logger(__name__)
1010
+
1011
+
1012
+ class TimestepEmbedding(nn.Module):
1013
+ def __init__(
1014
+ self,
1015
+ in_channels: int,
1016
+ time_embed_dim: int,
1017
+ act_fn: str = "silu",
1018
+ out_dim: int = None,
1019
+ post_act_fn: Optional[str] = None,
1020
+ cond_proj_dim=None,
1021
+ sample_proj_bias=True,
1022
+ ):
1023
+ super().__init__()
1024
+
1025
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
1026
+
1027
+ if cond_proj_dim is not None:
1028
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
1029
+ else:
1030
+ self.cond_proj = None
1031
+
1032
+ self.act = get_activation(act_fn)
1033
+
1034
+ if out_dim is not None:
1035
+ time_embed_dim_out = out_dim
1036
+ else:
1037
+ time_embed_dim_out = time_embed_dim
1038
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
1039
+
1040
+ if post_act_fn is None:
1041
+ self.post_act = None
1042
+ else:
1043
+ self.post_act = get_activation(post_act_fn)
1044
+
1045
+ self.initialize_weights()
1046
+
1047
+ def initialize_weights(self):
1048
+ nn.init.normal_(self.linear_1.weight, std=0.02)
1049
+ nn.init.zeros_(self.linear_1.bias)
1050
+ nn.init.normal_(self.linear_2.weight, std=0.02)
1051
+ nn.init.zeros_(self.linear_2.bias)
1052
+
1053
+ def forward(self, sample, condition=None):
1054
+ if condition is not None:
1055
+ sample = sample + self.cond_proj(condition)
1056
+ sample = self.linear_1(sample)
1057
+
1058
+ if self.act is not None:
1059
+ sample = self.act(sample)
1060
+
1061
+ sample = self.linear_2(sample)
1062
+
1063
+ if self.post_act is not None:
1064
+ sample = self.post_act(sample)
1065
+ return sample
1066
+
1067
+ def apply_rotary_emb(
1068
+ x: torch.Tensor,
1069
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
1070
+ use_real: bool = True,
1071
+ use_real_unbind_dim: int = -1,
1072
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1073
+ """
1074
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
1075
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
1076
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
1077
+ tensors contain rotary embeddings and are returned as real tensors.
1078
+
1079
+ Args:
1080
+ x (`torch.Tensor`):
1081
+ Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
1082
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
1083
+
1084
+ Returns:
1085
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
1086
+ """
1087
+ if use_real:
1088
+ cos, sin = freqs_cis # [S, D]
1089
+ cos = cos[None, None]
1090
+ sin = sin[None, None]
1091
+ cos, sin = cos.to(x.device), sin.to(x.device)
1092
+
1093
+ if use_real_unbind_dim == -1:
1094
+ # Used for flux, cogvideox, hunyuan-dit
1095
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
1096
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
1097
+ elif use_real_unbind_dim == -2:
1098
+ # Used for Stable Audio, OmniGen and CogView4
1099
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
1100
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
1101
+ else:
1102
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
1103
+
1104
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
1105
+
1106
+ return out
1107
+ else:
1108
+ # used for lumina
1109
+ # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
1110
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
1111
+ freqs_cis = freqs_cis.unsqueeze(2)
1112
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
1113
+
1114
+ return x_out.type_as(x)
1115
+
1116
+ class OmniGen2RotaryPosEmbed(nn.Module):
1117
+ def __init__(self, theta: int,
1118
+ axes_dim: Tuple[int, int, int],
1119
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
1120
+ patch_size: int = 2):
1121
+ super().__init__()
1122
+ self.theta = theta
1123
+ self.axes_dim = axes_dim
1124
+ self.axes_lens = axes_lens
1125
+ self.patch_size = patch_size
1126
+
1127
+ @staticmethod
1128
+ def get_freqs_cis(axes_dim: Tuple[int, int, int],
1129
+ axes_lens: Tuple[int, int, int],
1130
+ theta: int) -> List[torch.Tensor]:
1131
+ freqs_cis = []
1132
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
1133
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
1134
+ emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype)
1135
+ freqs_cis.append(emb)
1136
+ return freqs_cis
1137
+
1138
+ def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
1139
+ device = ids.device
1140
+ if ids.device.type == "mps":
1141
+ ids = ids.to("cpu")
1142
+
1143
+ result = []
1144
+ for i in range(len(self.axes_dim)):
1145
+ freqs = freqs_cis[i].to(ids.device)
1146
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
1147
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
1148
+ return torch.cat(result, dim=-1).to(device)
1149
+
1150
+ def forward(
1151
+ self,
1152
+ freqs_cis,
1153
+ attention_mask,
1154
+ l_effective_ref_img_len,
1155
+ l_effective_img_len,
1156
+ ref_img_sizes,
1157
+ img_sizes,
1158
+ device
1159
+ ):
1160
+ batch_size = len(attention_mask)
1161
+ p = self.patch_size
1162
+
1163
+ encoder_seq_len = attention_mask.shape[1]
1164
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
1165
+
1166
+ seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)]
1167
+
1168
+ max_seq_len = max(seq_lengths)
1169
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
1170
+ max_img_len = max(l_effective_img_len)
1171
+
1172
+ # Create position IDs
1173
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
1174
+
1175
+ for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
1176
+ # add text position ids
1177
+ position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3")
1178
+
1179
+ pe_shift = cap_seq_len
1180
+ pe_shift_len = cap_seq_len
1181
+
1182
+ if ref_img_sizes[i] is not None:
1183
+ for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
1184
+ H, W = ref_img_size
1185
+ ref_H_tokens, ref_W_tokens = H // p, W // p
1186
+ assert ref_H_tokens * ref_W_tokens == ref_img_len
1187
+ # add image position ids
1188
+
1189
+ row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten()
1190
+ col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten()
1191
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift
1192
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids
1193
+ position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids
1194
+
1195
+ pe_shift += max(ref_H_tokens, ref_W_tokens)
1196
+ pe_shift_len += ref_img_len
1197
+
1198
+ H, W = img_sizes[i]
1199
+ H_tokens, W_tokens = H // p, W // p
1200
+ assert H_tokens * W_tokens == l_effective_img_len[i]
1201
+
1202
+ row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten()
1203
+ col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten()
1204
+
1205
+ assert pe_shift_len + l_effective_img_len[i] == seq_len
1206
+ position_ids[i, pe_shift_len: seq_len, 0] = pe_shift
1207
+ position_ids[i, pe_shift_len: seq_len, 1] = row_ids
1208
+ position_ids[i, pe_shift_len: seq_len, 2] = col_ids
1209
+
1210
+ # Get combined rotary embeddings
1211
+ freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
1212
+
1213
+ # create separate rotary embeddings for captions and images
1214
+ cap_freqs_cis = torch.zeros(
1215
+ batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1216
+ )
1217
+ ref_img_freqs_cis = torch.zeros(
1218
+ batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1219
+ )
1220
+ img_freqs_cis = torch.zeros(
1221
+ batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
1222
+ )
1223
+
1224
+ for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)):
1225
+ cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
1226
+ ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)]
1227
+ img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len]
1228
+
1229
+ return (
1230
+ cap_freqs_cis,
1231
+ ref_img_freqs_cis,
1232
+ img_freqs_cis,
1233
+ freqs_cis,
1234
+ l_effective_cap_len,
1235
+ seq_lengths,
1236
+ )
1237
+
1238
+
1239
+ class LuminaRMSNormZero(nn.Module):
1240
+ """
1241
+ Norm layer adaptive RMS normalization zero.
1242
+
1243
+ Parameters:
1244
+ embedding_dim (`int`): The size of each embedding vector.
1245
+ """
1246
+
1247
+ def __init__(
1248
+ self,
1249
+ embedding_dim: int,
1250
+ norm_eps: float,
1251
+ norm_elementwise_affine: bool,
1252
+ ):
1253
+ super().__init__()
1254
+ self.silu = nn.SiLU()
1255
+ self.linear = nn.Linear(
1256
+ min(embedding_dim, 1024),
1257
+ 4 * embedding_dim,
1258
+ bias=True,
1259
+ )
1260
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)
1261
+
1262
+ def forward(
1263
+ self,
1264
+ x: torch.Tensor,
1265
+ emb: Optional[torch.Tensor] = None,
1266
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
1267
+ emb = self.linear(self.silu(emb))
1268
+ scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
1269
+ x = self.norm(x) * (1 + scale_msa[:, None])
1270
+ return x, gate_msa, scale_mlp, gate_mlp
1271
+
1272
+
1273
+ class LuminaLayerNormContinuous(nn.Module):
1274
+ def __init__(
1275
+ self,
1276
+ embedding_dim: int,
1277
+ conditioning_embedding_dim: int,
1278
+ # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
1279
+ # because the output is immediately scaled and shifted by the projected conditioning embeddings.
1280
+ # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
1281
+ # However, this is how it was implemented in the original code, and it's rather likely you should
1282
+ # set `elementwise_affine` to False.
1283
+ elementwise_affine=True,
1284
+ eps=1e-5,
1285
+ bias=True,
1286
+ norm_type="layer_norm",
1287
+ out_dim: Optional[int] = None,
1288
+ ):
1289
+ super().__init__()
1290
+
1291
+ # AdaLN
1292
+ self.silu = nn.SiLU()
1293
+ self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
1294
+
1295
+ if norm_type == "layer_norm":
1296
+ self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
1297
+ elif norm_type == "rms_norm":
1298
+ self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
1299
+ else:
1300
+ raise ValueError(f"unknown norm_type {norm_type}")
1301
+
1302
+ self.linear_2 = None
1303
+ if out_dim is not None:
1304
+ self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
1305
+
1306
+ def forward(
1307
+ self,
1308
+ x: torch.Tensor,
1309
+ conditioning_embedding: torch.Tensor,
1310
+ ) -> torch.Tensor:
1311
+ # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
1312
+ emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
1313
+ scale = emb
1314
+ x = self.norm(x) * (1 + scale)[:, None, :]
1315
+
1316
+ if self.linear_2 is not None:
1317
+ x = self.linear_2(x)
1318
+
1319
+ return x
1320
+
1321
+
1322
+ class LuminaFeedForward(nn.Module):
1323
+ r"""
1324
+ A feed-forward layer.
1325
+
1326
+ Parameters:
1327
+ hidden_size (`int`):
1328
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
1329
+ hidden representations.
1330
+ intermediate_size (`int`): The intermediate dimension of the feedforward layer.
1331
+ multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
1332
+ of this value.
1333
+ ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
1334
+ dimension. Defaults to None.
1335
+ """
1336
+
1337
+ def __init__(
1338
+ self,
1339
+ dim: int,
1340
+ inner_dim: int,
1341
+ multiple_of: Optional[int] = 256,
1342
+ ffn_dim_multiplier: Optional[float] = None,
1343
+ ):
1344
+ super().__init__()
1345
+
1346
+ self.swiglu = swiglu
1347
+
1348
+ # custom hidden_size factor multiplier
1349
+ if ffn_dim_multiplier is not None:
1350
+ inner_dim = int(ffn_dim_multiplier * inner_dim)
1351
+ inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
1352
+
1353
+ self.linear_1 = nn.Linear(
1354
+ dim,
1355
+ inner_dim,
1356
+ bias=False,
1357
+ )
1358
+ self.linear_2 = nn.Linear(
1359
+ inner_dim,
1360
+ dim,
1361
+ bias=False,
1362
+ )
1363
+ self.linear_3 = nn.Linear(
1364
+ dim,
1365
+ inner_dim,
1366
+ bias=False,
1367
+ )
1368
+
1369
+ def forward(self, x):
1370
+ h1, h2 = self.linear_1(x), self.linear_3(x)
1371
+ return self.linear_2(self.swiglu(h1, h2))
1372
+
1373
+
1374
+ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
1375
+ def __init__(
1376
+ self,
1377
+ hidden_size: int = 4096,
1378
+ text_feat_dim: int = 2048,
1379
+ frequency_embedding_size: int = 256,
1380
+ norm_eps: float = 1e-5,
1381
+ timestep_scale: float = 1.0,
1382
+ ) -> None:
1383
+ super().__init__()
1384
+
1385
+ self.time_proj = Timesteps(
1386
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
1387
+ )
1388
+
1389
+ self.timestep_embedder = TimestepEmbedding(
1390
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
1391
+ )
1392
+
1393
+ self.caption_embedder = nn.Sequential(
1394
+ RMSNorm(text_feat_dim, eps=norm_eps),
1395
+ nn.Linear(text_feat_dim, hidden_size, bias=True),
1396
+ )
1397
+
1398
+ self._initialize_weights()
1399
+
1400
+ def _initialize_weights(self):
1401
+ nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
1402
+ nn.init.zeros_(self.caption_embedder[1].bias)
1403
+
1404
+ def forward(
1405
+ self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
1406
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1407
+ timestep_proj = self.time_proj(timestep).to(dtype=dtype)
1408
+ time_embed = self.timestep_embedder(timestep_proj)
1409
+ caption_embed = self.caption_embedder(text_hidden_states)
1410
+ return time_embed, caption_embed
1411
+
1412
+
1413
+ class OmniGen2AttnProcessor:
1414
+ """
1415
+ Processor for implementing scaled dot-product attention.
1416
+
1417
+ This processor is optimized for PyTorch 2.0 and implements:
1418
+ - Flash attention with variable length sequences
1419
+ - Rotary position embeddings (RoPE)
1420
+ - Query-Key normalization
1421
+ - Proportional attention scaling
1422
+
1423
+ Args:
1424
+ None
1425
+
1426
+ Raises:
1427
+ ImportError: If PyTorch version is less than 2.0
1428
+ """
1429
+
1430
+ def __init__(self) -> None:
1431
+ """Initialize the attention processor."""
1432
+ if not hasattr(F, "scaled_dot_product_attention"):
1433
+ raise ImportError(
1434
+ "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
1435
+ "Please upgrade PyTorch to version 2.0 or later."
1436
+ )
1437
+
1438
+ def __call__(
1439
+ self,
1440
+ attn: Attention,
1441
+ hidden_states: torch.Tensor,
1442
+ encoder_hidden_states: torch.Tensor,
1443
+ attention_mask: Optional[torch.Tensor] = None,
1444
+ image_rotary_emb: Optional[torch.Tensor] = None,
1445
+ base_sequence_length: Optional[int] = None,
1446
+ ) -> torch.Tensor:
1447
+ """
1448
+ Process attention computation with flash attention.
1449
+
1450
+ Args:
1451
+ attn: Attention module
1452
+ hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
1453
+ encoder_hidden_states: Encoder hidden states tensor
1454
+ attention_mask: Optional attention mask tensor
1455
+ image_rotary_emb: Optional rotary embeddings for image tokens
1456
+ base_sequence_length: Optional base sequence length for proportional attention
1457
+
1458
+ Returns:
1459
+ torch.Tensor: Processed hidden states after attention computation
1460
+ """
1461
+ batch_size, sequence_length, _ = hidden_states.shape
1462
+
1463
+ # Get Query-Key-Value Pair
1464
+ query = attn.to_q(hidden_states)
1465
+ key = attn.to_k(encoder_hidden_states)
1466
+ value = attn.to_v(encoder_hidden_states)
1467
+
1468
+ query_dim = query.shape[-1]
1469
+ inner_dim = key.shape[-1]
1470
+ head_dim = query_dim // attn.heads
1471
+ dtype = query.dtype
1472
+
1473
+ # Get key-value heads
1474
+ kv_heads = inner_dim // head_dim
1475
+
1476
+ # Reshape tensors for attention computation
1477
+ query = query.view(batch_size, -1, attn.heads, head_dim)
1478
+ key = key.view(batch_size, -1, kv_heads, head_dim)
1479
+ value = value.view(batch_size, -1, kv_heads, head_dim)
1480
+
1481
+ # Apply Query-Key normalization
1482
+ if attn.norm_q is not None:
1483
+ query = attn.norm_q(query)
1484
+ if attn.norm_k is not None:
1485
+ key = attn.norm_k(key)
1486
+
1487
+ # Apply Rotary Position Embeddings
1488
+ if image_rotary_emb is not None:
1489
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
1490
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
1491
+
1492
+ query, key = query.to(dtype), key.to(dtype)
1493
+
1494
+ # Calculate attention scale
1495
+ if base_sequence_length is not None:
1496
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
1497
+ else:
1498
+ softmax_scale = attn.scale
1499
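+ # Note (added for clarity): with proportional attention the scale is
+ # sqrt(log(sequence_length) / log(base_sequence_length)) * attn.scale, so logits are sharpened
+ # once the sequence grows past the base length (e.g. base 1024 and length 4096 gives
+ # sqrt(1.2) ≈ 1.095x the default scale) and unchanged when the two lengths are equal.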
+
1500
+ # scaled_dot_product_attention expects attention_mask shape to be
1501
+ # (batch, heads, query_length, key_length)
1502
+ if attention_mask is not None:
1503
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
1504
+
1505
+ query = query.transpose(1, 2)
1506
+ key = key.transpose(1, 2)
1507
+ value = value.transpose(1, 2)
1508
+
1509
+ # Explicitly repeat key and value along the head dimension to match the number of query heads; in our tests with PyTorch 2.6, relying on enable_gqa=True instead forced SDPA onto the MATH backend.
1510
+ key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
1511
+ value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
1512
+
1513
+ hidden_states = F.scaled_dot_product_attention(
1514
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
1515
+ )
1516
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1517
+ hidden_states = hidden_states.type_as(query)
1518
+
1519
+ # Apply output projection
1520
+ hidden_states = attn.to_out[0](hidden_states)
1521
+ hidden_states = attn.to_out[1](hidden_states)
1522
+
1523
+ return hidden_states
1524
+
1525
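+ # Usage sketch for OmniGen2AttnProcessor (illustrative; mirrors the Attention construction in
+ # OmniGen2TransformerBlock below, with `dim`, `num_heads`, `num_kv_heads` and `rope` standing in
+ # for caller-supplied values):
+ #   attn = Attention(query_dim=dim, dim_head=dim // num_heads, heads=num_heads, kv_heads=num_kv_heads,
+ #                    qk_norm="rms_norm", eps=1e-5, bias=False, out_bias=False,
+ #                    processor=OmniGen2AttnProcessor())
+ #   out = attn(hidden_states=x, encoder_hidden_states=x, image_rotary_emb=rope)  # self-attention
+ 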
+ class OmniGen2TransformerBlock(nn.Module):
1526
+ """
1527
+ Transformer block for OmniGen2 model.
1528
+
1529
+ This block implements a transformer layer with:
1530
+ - Multi-head attention with flash attention
1531
+ - Feed-forward network with SwiGLU activation
1532
+ - RMS normalization
1533
+ - Optional modulation for conditional generation
1534
+
1535
+ Args:
1536
+ dim: Dimension of the input and output tensors
1537
+ num_attention_heads: Number of attention heads
1538
+ num_kv_heads: Number of key-value heads
1539
+ multiple_of: Value that the feed-forward hidden dimension must be a multiple of
1540
+ ffn_dim_multiplier: Multiplier for the feed-forward network dimension
1541
+ norm_eps: Epsilon value for normalization layers
1542
+ modulation: Whether to use modulation for conditional generation
1543
+ """
1546
+
1547
+ def __init__(
1548
+ self,
1549
+ dim: int,
1550
+ num_attention_heads: int,
1551
+ num_kv_heads: int,
1552
+ multiple_of: int,
1553
+ ffn_dim_multiplier: float,
1554
+ norm_eps: float,
1555
+ modulation: bool = True,
1556
+ ) -> None:
1557
+ """Initialize the transformer block."""
1558
+ super().__init__()
1559
+ self.head_dim = dim // num_attention_heads
1560
+ self.modulation = modulation
1561
+
1562
+ # Initialize attention layer
1563
+ self.attn = Attention(
1564
+ query_dim=dim,
1565
+ cross_attention_dim=None,
1566
+ dim_head=dim // num_attention_heads,
1567
+ qk_norm="rms_norm",
1568
+ heads=num_attention_heads,
1569
+ kv_heads=num_kv_heads,
1570
+ eps=1e-5,
1571
+ bias=False,
1572
+ out_bias=False,
1573
+ processor=OmniGen2AttnProcessor(),
1574
+ )
1575
+
1576
+ # Initialize feed-forward network
1577
+ self.feed_forward = LuminaFeedForward(
1578
+ dim=dim,
1579
+ inner_dim=4 * dim,
1580
+ multiple_of=multiple_of,
1581
+ ffn_dim_multiplier=ffn_dim_multiplier,
1582
+ )
1583
+
1584
+ # Initialize normalization layers
1585
+ if modulation:
1586
+ self.norm1 = LuminaRMSNormZero(
1587
+ embedding_dim=dim,
1588
+ norm_eps=norm_eps,
1589
+ norm_elementwise_affine=True,
1590
+ )
1591
+ else:
1592
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
1593
+
1594
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
1595
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
1596
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
1597
+
1598
+ self.initialize_weights()
1599
+
1600
+ def initialize_weights(self) -> None:
1601
+ """
1602
+ Initialize the weights of the transformer block.
1603
+
1604
+ Uses Xavier uniform initialization for linear layers and zero initialization for biases.
1605
+ """
1606
+ nn.init.xavier_uniform_(self.attn.to_q.weight)
1607
+ nn.init.xavier_uniform_(self.attn.to_k.weight)
1608
+ nn.init.xavier_uniform_(self.attn.to_v.weight)
1609
+ nn.init.xavier_uniform_(self.attn.to_out[0].weight)
1610
+
1611
+ nn.init.xavier_uniform_(self.feed_forward.linear_1.weight)
1612
+ nn.init.xavier_uniform_(self.feed_forward.linear_2.weight)
1613
+ nn.init.xavier_uniform_(self.feed_forward.linear_3.weight)
1614
+
1615
+ if self.modulation:
1616
+ nn.init.zeros_(self.norm1.linear.weight)
1617
+ nn.init.zeros_(self.norm1.linear.bias)
1618
+
1619
+ def forward(
1620
+ self,
1621
+ hidden_states: torch.Tensor,
1622
+ attention_mask: torch.Tensor,
1623
+ image_rotary_emb: torch.Tensor,
1624
+ temb: Optional[torch.Tensor] = None,
1625
+ ) -> torch.Tensor:
1626
+ """
1627
+ Forward pass of the transformer block.
1628
+
1629
+ Args:
1630
+ hidden_states: Input hidden states tensor
1631
+ attention_mask: Attention mask tensor
1632
+ image_rotary_emb: Rotary embeddings for image tokens
1633
+ temb: Optional timestep embedding tensor
1634
+
1635
+ Returns:
1636
+ torch.Tensor: Output hidden states after transformer block processing
1637
+ """
1638
+ if self.modulation:
1639
+ if temb is None:
1640
+ raise ValueError("temb must be provided when modulation is enabled")
1641
+
1642
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
1643
+ attn_output = self.attn(
1644
+ hidden_states=norm_hidden_states,
1645
+ encoder_hidden_states=norm_hidden_states,
1646
+ attention_mask=attention_mask,
1647
+ image_rotary_emb=image_rotary_emb,
1648
+ )
1649
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
1650
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
1651
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
1652
+ else:
1653
+ norm_hidden_states = self.norm1(hidden_states)
1654
+ attn_output = self.attn(
1655
+ hidden_states=norm_hidden_states,
1656
+ encoder_hidden_states=norm_hidden_states,
1657
+ attention_mask=attention_mask,
1658
+ image_rotary_emb=image_rotary_emb,
1659
+ )
1660
+ hidden_states = hidden_states + self.norm2(attn_output)
1661
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
1662
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
1663
+
1664
+ return hidden_states
1665
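+ # Illustrative note on forward (added for clarity; broadcasting over the sequence dimension is
+ # omitted): with modulation=True each residual branch is gated by the timestep embedding, roughly
+ #   x = x + tanh(gate_msa) * norm2(attn(norm1(x, temb)))
+ #   x = x + tanh(gate_mlp) * ffn_norm2(ff(ffn_norm1(x) * (1 + scale_mlp)))
+ # where gate_msa, scale_mlp and gate_mlp come from norm1's modulation of temb.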
+
1666
+
1667
+ class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
1668
+ """
1669
+ OmniGen2 Transformer 2D Model.
1670
+
1671
+ A transformer-based diffusion model for image generation with:
1672
+ - Patch-based image processing
1673
+ - Rotary position embeddings
1674
+ - Multi-head attention
1675
+ - Conditional generation support
1676
+
1677
+ Args:
1678
+ patch_size: Size of image patches
1679
+ in_channels: Number of input channels
1680
+ out_channels: Number of output channels (defaults to in_channels)
1681
+ hidden_size: Size of hidden layers
1682
+ num_layers: Number of transformer layers
1683
+ num_refiner_layers: Number of refiner layers
1684
+ num_attention_heads: Number of attention heads
1685
+ num_kv_heads: Number of key-value heads
1686
+ multiple_of: Value that the feed-forward hidden dimension must be a multiple of
1687
+ ffn_dim_multiplier: Multiplier for feed-forward network dimension
1688
+ norm_eps: Epsilon value for normalization layers
1689
+ axes_dim_rope: Dimensions for rotary position embeddings
1690
+ axes_lens: Lengths for rotary position embeddings
1691
+ text_feat_dim: Dimension of text features
1692
+ timestep_scale: Scale factor for timestep embeddings
1693
+ """
1696
+
1697
+ _supports_gradient_checkpointing = True
1698
+ _no_split_modules = ["OmniGen2TransformerBlock"]
1699
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
1700
+
1701
+ @register_to_config
1702
+ def __init__(
1703
+ self,
1704
+ patch_size: int = 2,
1705
+ in_channels: int = 16,
1706
+ out_channels: Optional[int] = None,
1707
+ hidden_size: int = 2304,
1708
+ num_layers: int = 26,
1709
+ num_refiner_layers: int = 2,
1710
+ num_attention_heads: int = 24,
1711
+ num_kv_heads: int = 8,
1712
+ multiple_of: int = 256,
1713
+ ffn_dim_multiplier: Optional[float] = None,
1714
+ norm_eps: float = 1e-5,
1715
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
1716
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
1717
+ text_feat_dim: int = 1024,
1718
+ timestep_scale: float = 1.0,
1719
+ ) -> None:
1720
+ """Initialize the OmniGen2 transformer model."""
1721
+ super().__init__()
1722
+
1723
+ # Validate configuration
1724
+ if (hidden_size // num_attention_heads) != sum(axes_dim_rope):
1725
+ raise ValueError(
1726
+ f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) "
1727
+ f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})"
1728
+ )
1729
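+ # Worked check with the defaults above (comment added for clarity): hidden_size // num_attention_heads
+ # = 2304 // 24 = 96 equals sum(axes_dim_rope) = 32 + 32 + 32 = 96, so the per-head dimension can be
+ # split across the three RoPE axes.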
+
1730
+ self.out_channels = out_channels or in_channels
1731
+
1732
+ # Initialize embeddings
1733
+ self.rope_embedder = OmniGen2RotaryPosEmbed(
1734
+ theta=10000,
1735
+ axes_dim=axes_dim_rope,
1736
+ axes_lens=axes_lens,
1737
+ patch_size=patch_size,
1738
+ )
1739
+
1740
+ self.x_embedder = nn.Linear(
1741
+ in_features=patch_size * patch_size * in_channels,
1742
+ out_features=hidden_size,
1743
+ )
1744
+
1745
+ self.ref_image_patch_embedder = nn.Linear(
1746
+ in_features=patch_size * patch_size * in_channels,
1747
+ out_features=hidden_size,
1748
+ )
1749
+
1750
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
1751
+ hidden_size=hidden_size,
1752
+ text_feat_dim=text_feat_dim,
1753
+ norm_eps=norm_eps,
1754
+ timestep_scale=timestep_scale,
1755
+ )
1756
+
1757
+ # Initialize transformer blocks
1758
+ self.noise_refiner = nn.ModuleList([
1759
+ OmniGen2TransformerBlock(
1760
+ hidden_size,
1761
+ num_attention_heads,
1762
+ num_kv_heads,
1763
+ multiple_of,
1764
+ ffn_dim_multiplier,
1765
+ norm_eps,
1766
+ modulation=True,
1767
+ )
1768
+ for _ in range(num_refiner_layers)
1769
+ ])
1770
+
1771
+ self.ref_image_refiner = nn.ModuleList([
1772
+ OmniGen2TransformerBlock(
1773
+ hidden_size,
1774
+ num_attention_heads,
1775
+ num_kv_heads,
1776
+ multiple_of,
1777
+ ffn_dim_multiplier,
1778
+ norm_eps,
1779
+ modulation=True,
1780
+ )
1781
+ for _ in range(num_refiner_layers)
1782
+ ])
1783
+
1784
+ self.context_refiner = nn.ModuleList(
1785
+ [
1786
+ OmniGen2TransformerBlock(
1787
+ hidden_size,
1788
+ num_attention_heads,
1789
+ num_kv_heads,
1790
+ multiple_of,
1791
+ ffn_dim_multiplier,
1792
+ norm_eps,
1793
+ modulation=False,
1794
+ )
1795
+ for _ in range(num_refiner_layers)
1796
+ ]
1797
+ )
1798
+
1799
+ # 3. Transformer blocks
1800
+ self.layers = nn.ModuleList(
1801
+ [
1802
+ OmniGen2TransformerBlock(
1803
+ hidden_size,
1804
+ num_attention_heads,
1805
+ num_kv_heads,
1806
+ multiple_of,
1807
+ ffn_dim_multiplier,
1808
+ norm_eps,
1809
+ modulation=True,
1810
+ )
1811
+ for _ in range(num_layers)
1812
+ ]
1813
+ )
1814
+
1815
+ # 4. Output norm & projection
1816
+ self.norm_out = LuminaLayerNormContinuous(
1817
+ embedding_dim=hidden_size,
1818
+ conditioning_embedding_dim=min(hidden_size, 1024),
1819
+ elementwise_affine=False,
1820
+ eps=1e-6,
1821
+ bias=True,
1822
+ out_dim=patch_size * patch_size * self.out_channels,
1823
+ )
1824
+
1825
+ # Add learnable embeddings to distinguish different images
1826
+ self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images
1827
+
1828
+ self.gradient_checkpointing = False
1829
+
1830
+ self.initialize_weights()
1831
+
1832
+ def initialize_weights(self) -> None:
1833
+ """
1834
+ Initialize the weights of the model.
1835
+
1836
+ Uses Xavier uniform initialization for linear layers.
1837
+ """
1838
+ nn.init.xavier_uniform_(self.x_embedder.weight)
1839
+ nn.init.constant_(self.x_embedder.bias, 0.0)
1840
+
1841
+ nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight)
1842
+ nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0)
1843
+
1844
+ nn.init.zeros_(self.norm_out.linear_1.weight)
1845
+ nn.init.zeros_(self.norm_out.linear_1.bias)
1846
+ nn.init.zeros_(self.norm_out.linear_2.weight)
1847
+ nn.init.zeros_(self.norm_out.linear_2.bias)
1848
+
1849
+ nn.init.normal_(self.image_index_embedding, std=0.02)
1850
+
1851
+ def img_patch_embed_and_refine(
1852
+ self,
1853
+ hidden_states,
1854
+ ref_image_hidden_states,
1855
+ padded_img_mask,
1856
+ padded_ref_img_mask,
1857
+ noise_rotary_emb,
1858
+ ref_img_rotary_emb,
1859
+ l_effective_ref_img_len,
1860
+ l_effective_img_len,
1861
+ temb
1862
+ ):
1863
+ batch_size = len(hidden_states)
1864
+ max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)])
1865
+
1866
+ hidden_states = self.x_embedder(hidden_states)
1867
+ ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
1868
+
1869
+ for i in range(batch_size):
1870
+ shift = 0
1871
+ for j, ref_img_len in enumerate(l_effective_ref_img_len[i]):
1872
+ ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j]
1873
+ shift += ref_img_len
1874
+
1875
+ for layer in self.noise_refiner:
1876
+ hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
1877
+
1878
+ flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len))
1879
+ num_ref_images = len(flat_l_effective_ref_img_len)
1880
+ max_ref_img_len = max(flat_l_effective_ref_img_len)
1881
+
1882
+ batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool)
1883
+ batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size)
1884
+ batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype)
1885
+ batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype)
1886
+
1887
+ # sequence of ref imgs to batch
1888
+ idx = 0
1889
+ for i in range(batch_size):
1890
+ shift = 0
1891
+ for ref_img_len in l_effective_ref_img_len[i]:
1892
+ batch_ref_img_mask[idx, :ref_img_len] = True
1893
+ batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len]
1894
+ batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len]
1895
+ batch_temb[idx] = temb[i]
1896
+ shift += ref_img_len
1897
+ idx += 1
1898
+
1899
+ # refine ref imgs separately
1900
+ for layer in self.ref_image_refiner:
1901
+ batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb)
1902
+
1903
+ # batch of ref imgs to sequence
1904
+ idx = 0
1905
+ for i in range(batch_size):
1906
+ shift = 0
1907
+ for ref_img_len in l_effective_ref_img_len[i]:
1908
+ ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len]
1909
+ shift += ref_img_len
1910
+ idx += 1
1911
+
1912
+ combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size)
1913
+ for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)):
1914
+ combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)]
1915
+ combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len]
1916
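+ # Layout note (added for clarity): per sample the combined sequence is [ref-image tokens | noise-image
+ # tokens], right-padded to max_combined_img_len; e.g. two 256-token reference images and a 1024-token
+ # noise image give ref tokens at [0:512) and noise tokens at [512:1536).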
+
1917
+ return combined_img_hidden_states
1918
+
1919
+ def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states):
1920
+ batch_size = len(hidden_states)
1921
+ p = self.config.patch_size
1922
+ device = hidden_states[0].device
1923
+
1924
+ img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
1925
+ l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes]
1926
+
1927
+ if ref_image_hidden_states is not None:
1928
+ ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states]
1929
+ l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes]
1930
+ else:
1931
+ ref_img_sizes = [None for _ in range(batch_size)]
1932
+ l_effective_ref_img_len = [[0] for _ in range(batch_size)]
1933
+
1934
+ max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len])
1935
+ max_img_len = max(l_effective_img_len)
1936
+
1937
+ # flatten reference image latents into patch tokens
1938
+ flat_ref_img_hidden_states = []
1939
+ for i in range(batch_size):
1940
+ if ref_img_sizes[i] is not None:
1941
+ imgs = []
1942
+ for ref_img in ref_image_hidden_states[i]:
1943
+ C, H, W = ref_img.size()
1944
+ ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
1945
+ imgs.append(ref_img)
1946
+
1947
+ img = torch.cat(imgs, dim=0)
1948
+ flat_ref_img_hidden_states.append(img)
1949
+ else:
1950
+ flat_ref_img_hidden_states.append(None)
1951
+
1952
+ # flatten noisy image latents into patch tokens
1953
+ flat_hidden_states = []
1954
+ for i in range(batch_size):
1955
+ img = hidden_states[i]
1956
+ C, H, W = img.size()
1957
+
1958
+ img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p)
1959
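+ # Patchify note (added for clarity): with the default patch_size p=2, a 16x64x64 latent becomes
+ # (64/2 * 64/2, 2*2*16) = (1024, 64) tokens, matching x_embedder's in_features of
+ # patch_size * patch_size * in_channels.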
+ flat_hidden_states.append(img)
1960
+
1961
+ padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
1962
+ padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device)
1963
+ for i in range(batch_size):
1964
+ if ref_img_sizes[i] is not None:
1965
+ padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i]
1966
+ padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True
1967
+
1968
+ padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype)
1969
+ padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
1970
+ for i in range(batch_size):
1971
+ padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i]
1972
+ padded_img_mask[i, :l_effective_img_len[i]] = True
1973
+
1974
+ return (
1975
+ padded_hidden_states,
1976
+ padded_ref_img_hidden_states,
1977
+ padded_img_mask,
1978
+ padded_ref_img_mask,
1979
+ l_effective_ref_img_len,
1980
+ l_effective_img_len,
1981
+ ref_img_sizes,
1982
+ img_sizes,
1983
+ )
1984
+
1985
+ def forward(
1986
+ self,
1987
+ hidden_states: Union[torch.Tensor, List[torch.Tensor]],
1988
+ timestep: torch.Tensor,
1989
+ text_hidden_states: torch.Tensor,
1990
+ freqs_cis: torch.Tensor,
1991
+ text_attention_mask: torch.Tensor,
1992
+ ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None,
1993
+ attention_kwargs: Optional[Dict[str, Any]] = None,
1994
+ return_dict: bool = False,
1995
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
1996
+ if attention_kwargs is not None:
1997
+ attention_kwargs = attention_kwargs.copy()
1998
+ lora_scale = attention_kwargs.pop("scale", 1.0)
1999
+ else:
2000
+ lora_scale = 1.0
2001
+
2002
+ if USE_PEFT_BACKEND:
2003
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
2004
+ scale_lora_layers(self, lora_scale)
2005
+ else:
2006
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
2007
+ logger.warning(
2008
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
2009
+ )
2010
+
2011
+ # 1. Condition, positional & patch embedding
2012
+ batch_size = len(hidden_states)
2013
+ is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor)
2014
+
2015
+ if is_hidden_states_tensor:
2016
+ assert hidden_states.ndim == 4
2017
+ hidden_states = [_hidden_states for _hidden_states in hidden_states]
2018
+
2019
+ device = hidden_states[0].device
2020
+
2021
+ temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
2022
+
2023
+ (
2024
+ hidden_states,
2025
+ ref_image_hidden_states,
2026
+ img_mask,
2027
+ ref_img_mask,
2028
+ l_effective_ref_img_len,
2029
+ l_effective_img_len,
2030
+ ref_img_sizes,
2031
+ img_sizes,
2032
+ ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
2033
+
2034
+ (
2035
+ context_rotary_emb,
2036
+ ref_img_rotary_emb,
2037
+ noise_rotary_emb,
2038
+ rotary_emb,
2039
+ encoder_seq_lengths,
2040
+ seq_lengths,
2041
+ ) = self.rope_embedder(
2042
+ freqs_cis,
2043
+ text_attention_mask,
2044
+ l_effective_ref_img_len,
2045
+ l_effective_img_len,
2046
+ ref_img_sizes,
2047
+ img_sizes,
2048
+ device,
2049
+ )
2050
+
2051
+ # 2. Context refinement
2052
+ for layer in self.context_refiner:
2053
+ text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
2054
+
2055
+ combined_img_hidden_states = self.img_patch_embed_and_refine(
2056
+ hidden_states,
2057
+ ref_image_hidden_states,
2058
+ img_mask,
2059
+ ref_img_mask,
2060
+ noise_rotary_emb,
2061
+ ref_img_rotary_emb,
2062
+ l_effective_ref_img_len,
2063
+ l_effective_img_len,
2064
+ temb,
2065
+ )
2066
+
2067
+ # 3. Joint Transformer blocks
2068
+ max_seq_len = max(seq_lengths)
2069
+
2070
+ attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
2071
+ joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
2072
+ for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
2073
+ attention_mask[i, :seq_len] = True
2074
+ joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len]
2075
+ joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len]
2076
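+ # Layout note (added for clarity): each joint sequence is [text tokens | ref-image tokens |
+ # noise-image tokens], right-padded to max_seq_len, with attention_mask marking the first seq_len
+ # positions of each sample as valid.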
+
2077
+ hidden_states = joint_hidden_states
2078
+
2079
+ for layer_idx, layer in enumerate(self.layers):
2080
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2081
+ hidden_states = self._gradient_checkpointing_func(
2082
+ layer, hidden_states, attention_mask, rotary_emb, temb
2083
+ )
2084
+ else:
2085
+ hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
2086
+
2087
+ # 4. Output norm & projection
2088
+ hidden_states = self.norm_out(hidden_states, temb)
2089
+
2090
+ p = self.config.patch_size
2091
+ output = []
2092
+ for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)):
2093
+ height, width = img_size
2094
+ output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p))
2095
+ if is_hidden_states_tensor:
2096
+ output = torch.stack(output, dim=0)
2097
+
2098
+ if USE_PEFT_BACKEND:
2099
+ # remove `lora_scale` from each PEFT layer
2100
+ unscale_lora_layers(self, lora_scale)
2101
+
2102
+ if not return_dict:
2103
+ return output
2104
+ return Transformer2DModelOutput(sample=output)
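+ 
+ 
+ # Minimal usage sketch (illustrative only; shapes assume the default config in __init__ above, and
+ # `freqs_cis` is taken as given since it is produced by the rope embedder / pipeline, not built here):
+ #   model = OmniGen2Transformer2DModel()  # or OmniGen2Transformer2DModel.from_pretrained(..., subfolder="transformer")
+ #   latents = torch.randn(1, 16, 64, 64)                        # (B, in_channels, H, W)
+ #   timestep = torch.tensor([500.0])
+ #   text_hidden_states = torch.randn(1, 77, 1024)                # (B, text_len, text_feat_dim); 77 is an arbitrary example length
+ #   text_attention_mask = torch.ones(1, 77, dtype=torch.bool)
+ #   out = model(latents, timestep, text_hidden_states, freqs_cis, text_attention_mask,
+ #               return_dict=True).sample                          # (B, out_channels, H, W)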