import warnings import itertools from typing import Any, Dict, List, Optional, Tuple, Union import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import PeftAdapterMixin from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from diffusers.models.attention_processor import Attention from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin from diffusers.models.embeddings import get_1d_rotary_pos_embed from diffusers.models.activations import get_activation from diffusers.models.embeddings import Timesteps import importlib.util import sys # The package importlib_metadata is in a different place, depending on the python version. if sys.version_info < (3, 8): import importlib_metadata else: import importlib.metadata as importlib_metadata def _is_package_available(pkg_name: str): pkg_exists = importlib.util.find_spec(pkg_name) is not None pkg_version = "N/A" if pkg_exists: try: pkg_version = importlib_metadata.version(pkg_name) except (ImportError, importlib_metadata.PackageNotFoundError): pkg_exists = False return pkg_exists, pkg_version _triton_available, _triton_version = _is_package_available("triton") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") def is_triton_available(): return _triton_available def is_flash_attn_available(): return _flash_attn_available if is_triton_available(): # from ...ops.triton.layer_norm import RMSNorm import triton import triton.language as tl from typing import Callable def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): def decorator(*args, **kwargs): if cuda_amp_deprecated: kwargs["device_type"] = "cuda" return dec(*args, **kwargs) return decorator if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] deprecated = True from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] else: deprecated = False from torch.cuda.amp import custom_fwd, custom_bwd custom_fwd = custom_amp_decorator(custom_fwd, deprecated) custom_bwd = custom_amp_decorator(custom_bwd, deprecated) def triton_autotune_configs(): # Return configs with a valid warp count for the current device configs=[] # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024 max_threads_per_block=1024 # Default to warp size 32 if not defined by device warp_size=getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32) # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit warp_count=1 while warp_count*warp_size <= max_threads_per_block: configs.append(triton.Config({}, num_warps=warp_count)) warp_count*=2 return configs @triton.autotune( configs=triton_autotune_configs(), key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None}) @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None}) @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None}) @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None}) @triton.jit def _layer_norm_fwd_1pass_kernel( X, # pointer to the input Y, # pointer to the output W, # pointer to the weights B, # pointer to the biases RESIDUAL, # pointer to the residual X1, W1, B1, Y1, RESIDUAL_OUT, # pointer to the residual ROWSCALE, SEEDS, # Dropout seeds for each row DROPOUT_MASK, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_res_row, stride_res_out_row, stride_x1_row, stride_y1_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, # Dropout probability zero_centered_weight, # If true, add 1.0 to the weight IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr, STORE_DROPOUT_MASK: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_X1: tl.constexpr, HAS_W1: tl.constexpr, HAS_B1: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. row = tl.program_id(0) X += row * stride_x_row Y += row * stride_y_row if HAS_RESIDUAL: RESIDUAL += row * stride_res_row if STORE_RESIDUAL_OUT: RESIDUAL_OUT += row * stride_res_out_row if HAS_X1: X1 += row * stride_x1_row if HAS_W1: Y1 += row * stride_y1_row # Compute mean and variance cols = tl.arange(0, BLOCK_N) x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + row).to(tl.float32) x *= rowscale if HAS_DROPOUT: # Compute dropout mask # 7 rounds is good enough, and reduces register pressure keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N) if HAS_X1: x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + M + row).to(tl.float32) x1 *= rowscale if HAS_DROPOUT: # Compute dropout mask # 7 rounds is good enough, and reduces register pressure keep_mask = ( tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p ) x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0) if STORE_DROPOUT_MASK: tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N) x += x1 if HAS_RESIDUAL: residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32) x += residual if STORE_RESIDUAL_OUT: tl.store(RESIDUAL_OUT + cols, x, mask=cols < N) if not IS_RMS_NORM: mean = tl.sum(x, axis=0) / N tl.store(Mean + row, mean) xbar = tl.where(cols < N, x - mean, 0.0) var = tl.sum(xbar * xbar, axis=0) / N else: xbar = tl.where(cols < N, x, 0.0) var = tl.sum(xbar * xbar, axis=0) / N rstd = 1 / tl.sqrt(var + eps) tl.store(Rstd + row, rstd) # Normalize and apply linear transformation mask = cols < N w = tl.load(W + cols, mask=mask).to(tl.float32) if zero_centered_weight: w += 1.0 if HAS_BIAS: b = tl.load(B + cols, mask=mask).to(tl.float32) x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd y = x_hat * w + b if HAS_BIAS else x_hat * w # Write output tl.store(Y + cols, y, mask=mask) if HAS_W1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) if zero_centered_weight: w1 += 1.0 if HAS_B1: b1 = tl.load(B1 + cols, mask=mask).to(tl.float32) y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1 tl.store(Y1 + cols, y1, mask=mask) def _layer_norm_fwd( x, weight, bias, eps, residual=None, x1=None, weight1=None, bias1=None, dropout_p=0.0, rowscale=None, out_dtype=None, residual_dtype=None, zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, out=None, residual_out=None ): if residual is not None: residual_dtype = residual.dtype M, N = x.shape assert x.stride(-1) == 1 if residual is not None: assert residual.stride(-1) == 1 assert residual.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) if x1 is not None: assert x1.shape == x.shape assert rowscale is None assert x1.stride(-1) == 1 if weight1 is not None: assert weight1.shape == (N,) assert weight1.stride(-1) == 1 if bias1 is not None: assert bias1.shape == (N,) assert bias1.stride(-1) == 1 if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) # allocate output if out is None: out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype) else: assert out.shape == x.shape assert out.stride(-1) == 1 if weight1 is not None: y1 = torch.empty_like(out) assert y1.stride(-1) == 1 else: y1 = None if ( residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None ): if residual_out is None: residual_out = torch.empty( M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype ) else: assert residual_out.shape == x.shape assert residual_out.stride(-1) == 1 else: residual_out = None mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None rstd = torch.empty((M,), dtype=torch.float32, device=x.device) if dropout_p > 0.0: seeds = torch.randint( 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64 ) else: seeds = None if return_dropout_mask and dropout_p > 0.0: dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool) else: dropout_mask = None # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") with torch.cuda.device(x.device.index): _layer_norm_fwd_1pass_kernel[(M,)]( x, out, weight, bias, residual, x1, weight1, bias1, y1, residual_out, rowscale, seeds, dropout_mask, mean, rstd, x.stride(0), out.stride(0), residual.stride(0) if residual is not None else 0, residual_out.stride(0) if residual_out is not None else 0, x1.stride(0) if x1 is not None else 0, y1.stride(0) if y1 is not None else 0, M, N, eps, dropout_p, zero_centered_weight, is_rms_norm, BLOCK_N, residual is not None, residual_out is not None, bias is not None, dropout_p > 0.0, dropout_mask is not None, rowscale is not None, ) # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0 if dropout_mask is not None and x1 is not None: dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0) else: dropout_mask1 = None return ( out, y1, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask, dropout_mask1, ) @triton.autotune( configs=triton_autotune_configs(), key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"], ) # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None}) # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None}) @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None}) @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None}) @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None}) @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None}) @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None}) @triton.jit def _layer_norm_bwd_kernel( X, # pointer to the input W, # pointer to the weights B, # pointer to the biases Y, # pointer to the output to be recomputed DY, # pointer to the output gradient DX, # pointer to the input gradient DW, # pointer to the partial sum of weights gradient DB, # pointer to the partial sum of biases gradient DRESIDUAL, W1, DY1, DX1, DW1, DB1, DRESIDUAL_IN, ROWSCALE, SEEDS, Mean, # pointer to the mean Rstd, # pointer to the 1/std stride_x_row, # how much to increase the pointer when moving by 1 row stride_y_row, stride_dy_row, stride_dx_row, stride_dres_row, stride_dy1_row, stride_dx1_row, stride_dres_in_row, M, # number of rows in X N, # number of columns in X eps, # epsilon to avoid division by zero dropout_p, zero_centered_weight, rows_per_program, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_DRESIDUAL: tl.constexpr, STORE_DRESIDUAL: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_DY1: tl.constexpr, HAS_DX1: tl.constexpr, HAS_B1: tl.constexpr, RECOMPUTE_OUTPUT: tl.constexpr, ): # Map the program id to the elements of X, DX, and DY it should compute. row_block_id = tl.program_id(0) row_start = row_block_id * rows_per_program # Do not early exit if row_start >= M, because we need to write DW and DB cols = tl.arange(0, BLOCK_N) mask = cols < N X += row_start * stride_x_row if HAS_DRESIDUAL: DRESIDUAL += row_start * stride_dres_row if STORE_DRESIDUAL: DRESIDUAL_IN += row_start * stride_dres_in_row DY += row_start * stride_dy_row DX += row_start * stride_dx_row if HAS_DY1: DY1 += row_start * stride_dy1_row if HAS_DX1: DX1 += row_start * stride_dx1_row if RECOMPUTE_OUTPUT: Y += row_start * stride_y_row w = tl.load(W + cols, mask=mask).to(tl.float32) if zero_centered_weight: w += 1.0 if RECOMPUTE_OUTPUT and HAS_BIAS: b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32) if HAS_DY1: w1 = tl.load(W1 + cols, mask=mask).to(tl.float32) if zero_centered_weight: w1 += 1.0 dw = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_BIAS: db = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_DY1: dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32) if HAS_B1: db1 = tl.zeros((BLOCK_N,), dtype=tl.float32) row_end = min((row_block_id + 1) * rows_per_program, M) for row in range(row_start, row_end): # Load data to SRAM x = tl.load(X + cols, mask=mask, other=0).to(tl.float32) dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32) if HAS_DY1: dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32) if not IS_RMS_NORM: mean = tl.load(Mean + row) rstd = tl.load(Rstd + row) # Compute dx xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd xhat = tl.where(mask, xhat, 0.0) if RECOMPUTE_OUTPUT: y = xhat * w + b if HAS_BIAS else xhat * w tl.store(Y + cols, y, mask=mask) wdy = w * dy dw += dy * xhat if HAS_BIAS: db += dy if HAS_DY1: wdy += w1 * dy1 dw1 += dy1 * xhat if HAS_B1: db1 += dy1 if not IS_RMS_NORM: c1 = tl.sum(xhat * wdy, axis=0) / N c2 = tl.sum(wdy, axis=0) / N dx = (wdy - (xhat * c1 + c2)) * rstd else: c1 = tl.sum(xhat * wdy, axis=0) / N dx = (wdy - xhat * c1) * rstd if HAS_DRESIDUAL: dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32) dx += dres # Write dx if STORE_DRESIDUAL: tl.store(DRESIDUAL_IN + cols, dx, mask=mask) if HAS_DX1: if HAS_DROPOUT: keep_mask = ( tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p ) dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) else: dx1 = dx tl.store(DX1 + cols, dx1, mask=mask) if HAS_DROPOUT: keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0) if HAS_ROWSCALE: rowscale = tl.load(ROWSCALE + row).to(tl.float32) dx *= rowscale tl.store(DX + cols, dx, mask=mask) X += stride_x_row if HAS_DRESIDUAL: DRESIDUAL += stride_dres_row if STORE_DRESIDUAL: DRESIDUAL_IN += stride_dres_in_row if RECOMPUTE_OUTPUT: Y += stride_y_row DY += stride_dy_row DX += stride_dx_row if HAS_DY1: DY1 += stride_dy1_row if HAS_DX1: DX1 += stride_dx1_row tl.store(DW + row_block_id * N + cols, dw, mask=mask) if HAS_BIAS: tl.store(DB + row_block_id * N + cols, db, mask=mask) if HAS_DY1: tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask) if HAS_B1: tl.store(DB1 + row_block_id * N + cols, db1, mask=mask) def _layer_norm_bwd( dy, x, weight, bias, eps, mean, rstd, dresidual=None, dy1=None, weight1=None, bias1=None, seeds=None, dropout_p=0.0, rowscale=None, has_residual=False, has_x1=False, zero_centered_weight=False, is_rms_norm=False, x_dtype=None, recompute_output=False, ): M, N = x.shape assert x.stride(-1) == 1 assert dy.stride(-1) == 1 assert dy.shape == (M, N) if dresidual is not None: assert dresidual.stride(-1) == 1 assert dresidual.shape == (M, N) assert weight.shape == (N,) assert weight.stride(-1) == 1 if bias is not None: assert bias.stride(-1) == 1 assert bias.shape == (N,) if dy1 is not None: assert weight1 is not None assert dy1.shape == dy.shape assert dy1.stride(-1) == 1 if weight1 is not None: assert weight1.shape == (N,) assert weight1.stride(-1) == 1 if bias1 is not None: assert bias1.shape == (N,) assert bias1.stride(-1) == 1 if seeds is not None: assert seeds.is_contiguous() assert seeds.shape == (M if not has_x1 else M * 2,) if rowscale is not None: assert rowscale.is_contiguous() assert rowscale.shape == (M,) # allocate output dx = ( torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device) ) dresidual_in = ( torch.empty_like(x) if has_residual and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1) else None ) dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None if recompute_output: assert weight1 is None, "recompute_output is not supported with parallel LayerNorm" # Less than 64KB per feature: enqueue fused kernel MAX_FUSED_SIZE = 65536 // x.element_size() BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) if N > BLOCK_N: raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the # latency of the gmem reads/writes, but will increase the time of summing up dw / db. sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8 _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device) _db = ( torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None ) _dw1 = torch.empty_like(_dw) if weight1 is not None else None _db1 = torch.empty_like(_db) if bias1 is not None else None rows_per_program = math.ceil(M / sm_count) grid = (sm_count,) with torch.cuda.device(x.device.index): _layer_norm_bwd_kernel[grid]( x, weight, bias, y, dy, dx, _dw, _db, dresidual, weight1, dy1, dx1, _dw1, _db1, dresidual_in, rowscale, seeds, mean, rstd, x.stride(0), 0 if not recompute_output else y.stride(0), dy.stride(0), dx.stride(0), dresidual.stride(0) if dresidual is not None else 0, dy1.stride(0) if dy1 is not None else 0, dx1.stride(0) if dx1 is not None else 0, dresidual_in.stride(0) if dresidual_in is not None else 0, M, N, eps, dropout_p, zero_centered_weight, rows_per_program, is_rms_norm, BLOCK_N, dresidual is not None, dresidual_in is not None, bias is not None, dropout_p > 0.0, ) dw = _dw.sum(0).to(weight.dtype) db = _db.sum(0).to(bias.dtype) if bias is not None else None dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None # Don't need to compute dresidual_in separately in this case if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None: dresidual_in = dx if has_x1 and dropout_p == 0.0: dx1 = dx return ( (dx, dw, db, dresidual_in, dx1, dw1, db1) if not recompute_output else (dx, dw, db, dresidual_in, dx1, dw1, db1, y) ) class LayerNormFn(torch.autograd.Function): @staticmethod def forward( ctx, x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, zero_centered_weight=False, is_rms_norm=False, return_dropout_mask=False, out=None, residual_out=None ): x_shape_og = x.shape # Check for zero sequence length if x.numel() == 0: ctx.zero_seq_length = True # Only save minimal required tensors for backward # ctx.save_for_backward(weight, bias, weight1, bias1) ctx.x_shape_og = x_shape_og ctx.weight_shape = weight.shape ctx.weight_dtype = weight.dtype ctx.weight_device = weight.device ctx.has_bias = bias is not None ctx.bias_shape = bias.shape if bias is not None else None ctx.bias_dtype = bias.dtype if bias is not None else None ctx.bias_device = bias.device if bias is not None else None ctx.has_weight1 = weight1 is not None ctx.weight1_shape = weight1.shape if weight1 is not None else None ctx.weight1_dtype = weight1.dtype if weight1 is not None else None ctx.weight1_device = weight1.device if weight1 is not None else None ctx.has_bias1 = bias1 is not None ctx.bias1_shape = bias1.shape if bias1 is not None else None ctx.bias1_dtype = bias1.dtype if bias1 is not None else None ctx.bias1_device = bias1.device if bias1 is not None else None ctx.has_residual = residual is not None ctx.has_x1 = x1 is not None ctx.dropout_p = dropout_p # Handle output tensors with correct dtype y = x # Preserve input tensor properties y1 = torch.empty_like(x) if x1 is not None else None # Only create residual_out if prenorm is True residual_out = torch.empty(x.shape, dtype=torch.float32 if residual_in_fp32 else x.dtype, device=x.device) if prenorm else None # Handle dropout masks dropout_mask = None dropout_mask1 = None if return_dropout_mask: dropout_mask = torch.empty_like(x, dtype=torch.uint8) if x1 is not None: dropout_mask1 = torch.empty_like(x, dtype=torch.uint8) # Return based on configuration if not return_dropout_mask: if weight1 is None: return y if not prenorm else (y, residual_out) else: return (y, y1) if not prenorm else (y, y1, residual_out) else: if weight1 is None: return ((y, dropout_mask, dropout_mask1) if not prenorm else (y, residual_out, dropout_mask, dropout_mask1)) else: return ((y, y1, dropout_mask, dropout_mask1) if not prenorm else (y, y1, residual_out, dropout_mask, dropout_mask1)) ctx.zero_seq_length = False # reshape input data into 2D tensor x = x.reshape(-1, x.shape[-1]) if x.stride(-1) != 1: x = x.contiguous() if residual is not None: assert residual.shape == x_shape_og residual = residual.reshape(-1, residual.shape[-1]) if residual.stride(-1) != 1: residual = residual.contiguous() if x1 is not None: assert x1.shape == x_shape_og assert rowscale is None, "rowscale is not supported with parallel LayerNorm" x1 = x1.reshape(-1, x1.shape[-1]) if x1.stride(-1) != 1: x1 = x1.contiguous() weight = weight.contiguous() if bias is not None: bias = bias.contiguous() if weight1 is not None: weight1 = weight1.contiguous() if bias1 is not None: bias1 = bias1.contiguous() if rowscale is not None: rowscale = rowscale.reshape(-1).contiguous() residual_dtype = ( residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None) ) if out is not None: out = out.reshape(-1, out.shape[-1]) if residual_out is not None: residual_out = residual_out.reshape(-1, residual_out.shape[-1]) y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd( x, weight, bias, eps, residual, x1, weight1, bias1, dropout_p=dropout_p, rowscale=rowscale, residual_dtype=residual_dtype, zero_centered_weight=zero_centered_weight, is_rms_norm=is_rms_norm, return_dropout_mask=return_dropout_mask, out=out, residual_out=residual_out ) ctx.save_for_backward( residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd ) ctx.x_shape_og = x_shape_og ctx.eps = eps ctx.dropout_p = dropout_p ctx.is_rms_norm = is_rms_norm ctx.has_residual = residual is not None ctx.has_x1 = x1 is not None ctx.prenorm = prenorm ctx.x_dtype = x.dtype ctx.zero_centered_weight = zero_centered_weight y = y.reshape(x_shape_og) y1 = y1.reshape(x_shape_og) if y1 is not None else None residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None if not return_dropout_mask: if weight1 is None: return y if not prenorm else (y, residual_out) else: return (y, y1) if not prenorm else (y, y1, residual_out) else: if weight1 is None: return ( (y, dropout_mask, dropout_mask1) if not prenorm else (y, residual_out, dropout_mask, dropout_mask1) ) else: return ( (y, y1, dropout_mask, dropout_mask1) if not prenorm else (y, y1, residual_out, dropout_mask, dropout_mask1) ) @staticmethod def backward(ctx, dy, *args): if ctx.zero_seq_length: return ( torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device), torch.zeros(ctx.weight_shape, dtype=ctx.weight_dtype, device=ctx.weight_device), torch.zeros(ctx.bias_shape, dtype=ctx.bias_dtype, device=ctx.bias_device) if ctx.has_bias else None, torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_residual else None, torch.zeros(ctx.x_shape_og, dtype=dy.dtype, device=dy.device) if ctx.has_x1 and ctx.dropout_p > 0.0 else None, torch.zeros(ctx.weight1_shape, dtype=ctx.weight1_dtype, device=ctx.weight1_device) if ctx.has_weight1 else None, torch.zeros(ctx.bias1_shape, dtype=ctx.bias1_dtype, device=ctx.bias1_device) if ctx.has_bias1 else None, None, None, None, None, None, None, None, None, None, None, ) x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors dy = dy.reshape(-1, dy.shape[-1]) if dy.stride(-1) != 1: dy = dy.contiguous() assert dy.shape == x.shape if weight1 is not None: dy1, args = args[0], args[1:] dy1 = dy1.reshape(-1, dy1.shape[-1]) if dy1.stride(-1) != 1: dy1 = dy1.contiguous() assert dy1.shape == x.shape else: dy1 = None if ctx.prenorm: dresidual = args[0] dresidual = dresidual.reshape(-1, dresidual.shape[-1]) if dresidual.stride(-1) != 1: dresidual = dresidual.contiguous() assert dresidual.shape == x.shape else: dresidual = None dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd( dy, x, weight, bias, ctx.eps, mean, rstd, dresidual, dy1, weight1, bias1, seeds, ctx.dropout_p, rowscale, ctx.has_residual, ctx.has_x1, ctx.zero_centered_weight, ctx.is_rms_norm, x_dtype=ctx.x_dtype, ) return ( dx.reshape(ctx.x_shape_og), dw, db, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, dx1.reshape(ctx.x_shape_og) if dx1 is not None else None, dw1, db1, None, None, None, None, None, None, None, None, None, None, ) def rms_norm_fn( x, weight, bias, residual=None, x1=None, weight1=None, bias1=None, eps=1e-6, dropout_p=0.0, rowscale=None, prenorm=False, residual_in_fp32=False, zero_centered_weight=False, return_dropout_mask=False, out=None, residual_out=None ): return LayerNormFn.apply( x, weight, bias, residual, x1, weight1, bias1, eps, dropout_p, rowscale, prenorm, residual_in_fp32, zero_centered_weight, True, return_dropout_mask, out, residual_out ) class RMSNorm(torch.nn.Module): def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False, device=None, dtype=None): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps if dropout_p > 0.0: self.drop = torch.nn.Dropout(dropout_p) else: self.drop = None self.zero_centered_weight = zero_centered_weight self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.reset_parameters() def reset_parameters(self): if not self.zero_centered_weight: torch.nn.init.ones_(self.weight) else: torch.nn.init.zeros_(self.weight) def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False): return rms_norm_fn( x, self.weight, self.bias, residual=residual, eps=self.eps, dropout_p=self.drop.p if self.drop is not None and self.training else 0.0, prenorm=prenorm, residual_in_fp32=residual_in_fp32, zero_centered_weight=self.zero_centered_weight, ) else: from torch.nn import RMSNorm warnings.warn("Cannot import triton, install triton to use fused RMSNorm for better performance") def swiglu(x, y): return F.silu(x.float(), inplace=False).to(x.dtype) * y logger = logging.get_logger(__name__) class TimestepEmbedding(nn.Module): def __init__( self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, post_act_fn: Optional[str] = None, cond_proj_dim=None, sample_proj_bias=True, ): super().__init__() self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias) if cond_proj_dim is not None: self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) else: self.cond_proj = None self.act = get_activation(act_fn) if out_dim is not None: time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias) if post_act_fn is None: self.post_act = None else: self.post_act = get_activation(post_act_fn) self.initialize_weights() def initialize_weights(self): nn.init.normal_(self.linear_1.weight, std=0.02) nn.init.zeros_(self.linear_1.bias) nn.init.normal_(self.linear_2.weight, std=0.02) nn.init.zeros_(self.linear_2.bias) def forward(self, sample, condition=None): if condition is not None: sample = sample + self.cond_proj(condition) sample = self.linear_1(sample) if self.act is not None: sample = self.act(sample) sample = self.linear_2(sample) if self.post_act is not None: sample = self.post_act(sample) return sample def apply_rotary_emb( x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, use_real_unbind_dim: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are returned as real tensors. Args: x (`torch.Tensor`): Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],) Returns: Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ if use_real: cos, sin = freqs_cis # [S, D] cos = cos[None, None] sin = sin[None, None] cos, sin = cos.to(x.device), sin.to(x.device) if use_real_unbind_dim == -1: # Used for flux, cogvideox, hunyuan-dit x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) elif use_real_unbind_dim == -2: # Used for Stable Audio, OmniGen and CogView4 x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] x_rotated = torch.cat([-x_imag, x_real], dim=-1) else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out else: # used for lumina # x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2)) freqs_cis = freqs_cis.unsqueeze(2) x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3) return x_out.type_as(x) class OmniGen2RotaryPosEmbed(nn.Module): def __init__(self, theta: int, axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int] = (300, 512, 512), patch_size: int = 2): super().__init__() self.theta = theta self.axes_dim = axes_dim self.axes_lens = axes_lens self.patch_size = patch_size @staticmethod def get_freqs_cis(axes_dim: Tuple[int, int, int], axes_lens: Tuple[int, int, int], theta: int) -> List[torch.Tensor]: freqs_cis = [] freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): emb = get_1d_rotary_pos_embed(d, e, theta=theta, freqs_dtype=freqs_dtype) freqs_cis.append(emb) return freqs_cis def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor: device = ids.device if ids.device.type == "mps": ids = ids.to("cpu") result = [] for i in range(len(self.axes_dim)): freqs = freqs_cis[i].to(ids.device) index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) return torch.cat(result, dim=-1).to(device) def forward( self, freqs_cis, attention_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device ): batch_size = len(attention_mask) p = self.patch_size encoder_seq_len = attention_mask.shape[1] l_effective_cap_len = attention_mask.sum(dim=1).tolist() seq_lengths = [cap_len + sum(ref_img_len) + img_len for cap_len, ref_img_len, img_len in zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len)] max_seq_len = max(seq_lengths) max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) max_img_len = max(l_effective_img_len) # Create position IDs position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)): # add text position ids position_ids[i, :cap_seq_len] = repeat(torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3") pe_shift = cap_seq_len pe_shift_len = cap_seq_len if ref_img_sizes[i] is not None: for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]): H, W = ref_img_size ref_H_tokens, ref_W_tokens = H // p, W // p assert ref_H_tokens * ref_W_tokens == ref_img_len # add image position ids row_ids = repeat(torch.arange(ref_H_tokens, dtype=torch.int32, device=device), "h -> h w", w=ref_W_tokens).flatten() col_ids = repeat(torch.arange(ref_W_tokens, dtype=torch.int32, device=device), "w -> h w", h=ref_H_tokens).flatten() position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 0] = pe_shift position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 1] = row_ids position_ids[i, pe_shift_len:pe_shift_len + ref_img_len, 2] = col_ids pe_shift += max(ref_H_tokens, ref_W_tokens) pe_shift_len += ref_img_len H, W = img_sizes[i] H_tokens, W_tokens = H // p, W // p assert H_tokens * W_tokens == l_effective_img_len[i] row_ids = repeat(torch.arange(H_tokens, dtype=torch.int32, device=device), "h -> h w", w=W_tokens).flatten() col_ids = repeat(torch.arange(W_tokens, dtype=torch.int32, device=device), "w -> h w", h=H_tokens).flatten() assert pe_shift_len + l_effective_img_len[i] == seq_len position_ids[i, pe_shift_len: seq_len, 0] = pe_shift position_ids[i, pe_shift_len: seq_len, 1] = row_ids position_ids[i, pe_shift_len: seq_len, 2] = col_ids # Get combined rotary embeddings freqs_cis = self._get_freqs_cis(freqs_cis, position_ids) # create separate rotary embeddings for captions and images cap_freqs_cis = torch.zeros( batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype ) ref_img_freqs_cis = torch.zeros( batch_size, max_ref_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype ) img_freqs_cis = torch.zeros( batch_size, max_img_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype ) for i, (cap_seq_len, ref_img_len, img_len, seq_len) in enumerate(zip(l_effective_cap_len, l_effective_ref_img_len, l_effective_img_len, seq_lengths)): cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len] ref_img_freqs_cis[i, :sum(ref_img_len)] = freqs_cis[i, cap_seq_len:cap_seq_len + sum(ref_img_len)] img_freqs_cis[i, :img_len] = freqs_cis[i, cap_seq_len + sum(ref_img_len):cap_seq_len + sum(ref_img_len) + img_len] return ( cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths, ) class LuminaRMSNormZero(nn.Module): """ Norm layer adaptive RMS normalization zero. Parameters: embedding_dim (`int`): The size of each embedding vector. """ def __init__( self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool, ): super().__init__() self.silu = nn.SiLU() self.linear = nn.Linear( min(embedding_dim, 1024), 4 * embedding_dim, bias=True, ) self.norm = RMSNorm(embedding_dim, eps=norm_eps) def forward( self, x: torch.Tensor, emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: emb = self.linear(self.silu(emb)) scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) return x, gate_msa, scale_mlp, gate_mlp class LuminaLayerNormContinuous(nn.Module): def __init__( self, embedding_dim: int, conditioning_embedding_dim: int, # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters # because the output is immediately scaled and shifted by the projected conditioning embeddings. # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. # However, this is how it was implemented in the original code, and it's rather likely you should # set `elementwise_affine` to False. elementwise_affine=True, eps=1e-5, bias=True, norm_type="layer_norm", out_dim: Optional[int] = None, ): super().__init__() # AdaLN self.silu = nn.SiLU() self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias) if norm_type == "layer_norm": self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias) elif norm_type == "rms_norm": self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) else: raise ValueError(f"unknown norm_type {norm_type}") self.linear_2 = None if out_dim is not None: self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias) def forward( self, x: torch.Tensor, conditioning_embedding: torch.Tensor, ) -> torch.Tensor: # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT) emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype)) scale = emb x = self.norm(x) * (1 + scale)[:, None, :] if self.linear_2 is not None: x = self.linear_2(x) return x class LuminaFeedForward(nn.Module): r""" A feed-forward layer. Parameters: hidden_size (`int`): The dimensionality of the hidden layers in the model. This parameter determines the width of the model's hidden representations. intermediate_size (`int`): The intermediate dimension of the feedforward layer. multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple of this value. ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden dimension. Defaults to None. """ def __init__( self, dim: int, inner_dim: int, multiple_of: Optional[int] = 256, ffn_dim_multiplier: Optional[float] = None, ): super().__init__() self.swiglu = swiglu # custom hidden_size factor multiplier if ffn_dim_multiplier is not None: inner_dim = int(ffn_dim_multiplier * inner_dim) inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of) self.linear_1 = nn.Linear( dim, inner_dim, bias=False, ) self.linear_2 = nn.Linear( inner_dim, dim, bias=False, ) self.linear_3 = nn.Linear( dim, inner_dim, bias=False, ) def forward(self, x): h1, h2 = self.linear_1(x), self.linear_3(x) return self.linear_2(self.swiglu(h1, h2)) class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): def __init__( self, hidden_size: int = 4096, text_feat_dim: int = 2048, frequency_embedding_size: int = 256, norm_eps: float = 1e-5, timestep_scale: float = 1.0, ) -> None: super().__init__() self.time_proj = Timesteps( num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale ) self.timestep_embedder = TimestepEmbedding( in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) ) self.caption_embedder = nn.Sequential( RMSNorm(text_feat_dim, eps=norm_eps), nn.Linear(text_feat_dim, hidden_size, bias=True), ) self._initialize_weights() def _initialize_weights(self): nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02) nn.init.zeros_(self.caption_embedder[1].bias) def forward( self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype ) -> Tuple[torch.Tensor, torch.Tensor]: timestep_proj = self.time_proj(timestep).to(dtype=dtype) time_embed = self.timestep_embedder(timestep_proj) caption_embed = self.caption_embedder(text_hidden_states) return time_embed, caption_embed class OmniGen2AttnProcessor: """ Processor for implementing scaled dot-product attention. This processor is optimized for PyTorch 2.0 and implements: - Flash attention with variable length sequences - Rotary position embeddings (RoPE) - Query-Key normalization - Proportional attention scaling Args: None Raises: ImportError: If PyTorch version is less than 2.0 """ def __init__(self) -> None: """Initialize the attention processor.""" if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. " "Please upgrade PyTorch to version 2.0 or later." ) def __call__( self, attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, base_sequence_length: Optional[int] = None, ) -> torch.Tensor: """ Process attention computation with flash attention. Args: attn: Attention module hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim) encoder_hidden_states: Encoder hidden states tensor attention_mask: Optional attention mask tensor image_rotary_emb: Optional rotary embeddings for image tokens base_sequence_length: Optional base sequence length for proportional attention Returns: torch.Tensor: Processed hidden states after attention computation """ batch_size, sequence_length, _ = hidden_states.shape # Get Query-Key-Value Pair query = attn.to_q(hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) query_dim = query.shape[-1] inner_dim = key.shape[-1] head_dim = query_dim // attn.heads dtype = query.dtype # Get key-value heads kv_heads = inner_dim // head_dim # Reshape tensors for attention computation query = query.view(batch_size, -1, attn.heads, head_dim) key = key.view(batch_size, -1, kv_heads, head_dim) value = value.view(batch_size, -1, kv_heads, head_dim) # Apply Query-Key normalization if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # Apply Rotary Position Embeddings if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, use_real=False) key = apply_rotary_emb(key, image_rotary_emb, use_real=False) query, key = query.to(dtype), key.to(dtype) # Calculate attention scale if base_sequence_length is not None: softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale else: softmax_scale = attn.scale # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) if attention_mask is not None: attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) # explicitly repeat key and value to match query length, otherwise using enable_gqa=True results in MATH backend of sdpa in our test of pytorch2.6 key = key.repeat_interleave(query.size(-3) // key.size(-3), -3) value = value.repeat_interleave(query.size(-3) // value.size(-3), -3) hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, scale=softmax_scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.type_as(query) # Apply output projection hidden_states = attn.to_out[0](hidden_states) hidden_states = attn.to_out[1](hidden_states) return hidden_states class OmniGen2TransformerBlock(nn.Module): """ Transformer block for OmniGen2 model. This block implements a transformer layer with: - Multi-head attention with flash attention - Feed-forward network with SwiGLU activation - RMS normalization - Optional modulation for conditional generation Args: dim: Dimension of the input and output tensors num_attention_heads: Number of attention heads num_kv_heads: Number of key-value heads multiple_of: Multiple of which the hidden dimension should be ffn_dim_multiplier: Multiplier for the feed-forward network dimension norm_eps: Epsilon value for normalization layers modulation: Whether to use modulation for conditional generation use_fused_rms_norm: Whether to use fused RMS normalization use_fused_swiglu: Whether to use fused SwiGLU activation """ def __init__( self, dim: int, num_attention_heads: int, num_kv_heads: int, multiple_of: int, ffn_dim_multiplier: float, norm_eps: float, modulation: bool = True, ) -> None: """Initialize the transformer block.""" super().__init__() self.head_dim = dim // num_attention_heads self.modulation = modulation # Initialize attention layer self.attn = Attention( query_dim=dim, cross_attention_dim=None, dim_head=dim // num_attention_heads, qk_norm="rms_norm", heads=num_attention_heads, kv_heads=num_kv_heads, eps=1e-5, bias=False, out_bias=False, processor=OmniGen2AttnProcessor(), ) # Initialize feed-forward network self.feed_forward = LuminaFeedForward( dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, ) # Initialize normalization layers if modulation: self.norm1 = LuminaRMSNormZero( embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True, ) else: self.norm1 = RMSNorm(dim, eps=norm_eps) self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) self.norm2 = RMSNorm(dim, eps=norm_eps) self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) self.initialize_weights() def initialize_weights(self) -> None: """ Initialize the weights of the transformer block. Uses Xavier uniform initialization for linear layers and zero initialization for biases. """ nn.init.xavier_uniform_(self.attn.to_q.weight) nn.init.xavier_uniform_(self.attn.to_k.weight) nn.init.xavier_uniform_(self.attn.to_v.weight) nn.init.xavier_uniform_(self.attn.to_out[0].weight) nn.init.xavier_uniform_(self.feed_forward.linear_1.weight) nn.init.xavier_uniform_(self.feed_forward.linear_2.weight) nn.init.xavier_uniform_(self.feed_forward.linear_3.weight) if self.modulation: nn.init.zeros_(self.norm1.linear.weight) nn.init.zeros_(self.norm1.linear.bias) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Forward pass of the transformer block. Args: hidden_states: Input hidden states tensor attention_mask: Attention mask tensor image_rotary_emb: Rotary embeddings for image tokens temb: Optional timestep embedding tensor Returns: torch.Tensor: Output hidden states after transformer block processing """ if self.modulation: if temb is None: raise ValueError("temb must be provided when modulation is enabled") norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) else: norm_hidden_states = self.norm1(hidden_states) attn_output = self.attn( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) return hidden_states class OmniGen2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): """ OmniGen2 Transformer 2D Model. A transformer-based diffusion model for image generation with: - Patch-based image processing - Rotary position embeddings - Multi-head attention - Conditional generation support Args: patch_size: Size of image patches in_channels: Number of input channels out_channels: Number of output channels (defaults to in_channels) hidden_size: Size of hidden layers num_layers: Number of transformer layers num_refiner_layers: Number of refiner layers num_attention_heads: Number of attention heads num_kv_heads: Number of key-value heads multiple_of: Multiple of which the hidden dimension should be ffn_dim_multiplier: Multiplier for feed-forward network dimension norm_eps: Epsilon value for normalization layers axes_dim_rope: Dimensions for rotary position embeddings axes_lens: Lengths for rotary position embeddings text_feat_dim: Dimension of text features timestep_scale: Scale factor for timestep embeddings use_fused_rms_norm: Whether to use fused RMS normalization use_fused_swiglu: Whether to use fused SwiGLU activation """ _supports_gradient_checkpointing = True _no_split_modules = ["Omnigen2TransformerBlock"] _skip_layerwise_casting_patterns = ["x_embedder", "norm"] @register_to_config def __init__( self, patch_size: int = 2, in_channels: int = 16, out_channels: Optional[int] = None, hidden_size: int = 2304, num_layers: int = 26, num_refiner_layers: int = 2, num_attention_heads: int = 24, num_kv_heads: int = 8, multiple_of: int = 256, ffn_dim_multiplier: Optional[float] = None, norm_eps: float = 1e-5, axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), axes_lens: Tuple[int, int, int] = (300, 512, 512), text_feat_dim: int = 1024, timestep_scale: float = 1.0, ) -> None: """Initialize the OmniGen2 transformer model.""" super().__init__() # Validate configuration if (hidden_size // num_attention_heads) != sum(axes_dim_rope): raise ValueError( f"hidden_size // num_attention_heads ({hidden_size // num_attention_heads}) " f"must equal sum(axes_dim_rope) ({sum(axes_dim_rope)})" ) self.out_channels = out_channels or in_channels # Initialize embeddings self.rope_embedder = OmniGen2RotaryPosEmbed( theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size, ) self.x_embedder = nn.Linear( in_features=patch_size * patch_size * in_channels, out_features=hidden_size, ) self.ref_image_patch_embedder = nn.Linear( in_features=patch_size * patch_size * in_channels, out_features=hidden_size, ) self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( hidden_size=hidden_size, text_feat_dim=text_feat_dim, norm_eps=norm_eps, timestep_scale=timestep_scale, ) # Initialize transformer blocks self.noise_refiner = nn.ModuleList([ OmniGen2TransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, ) for _ in range(num_refiner_layers) ]) self.ref_image_refiner = nn.ModuleList([ OmniGen2TransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, ) for _ in range(num_refiner_layers) ]) self.context_refiner = nn.ModuleList( [ OmniGen2TransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, ) for _ in range(num_refiner_layers) ] ) # 3. Transformer blocks self.layers = nn.ModuleList( [ OmniGen2TransformerBlock( hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, ) for _ in range(num_layers) ] ) # 4. Output norm & projection self.norm_out = LuminaLayerNormContinuous( embedding_dim=hidden_size, conditioning_embedding_dim=min(hidden_size, 1024), elementwise_affine=False, eps=1e-6, bias=True, out_dim=patch_size * patch_size * self.out_channels, ) # Add learnable embeddings to distinguish different images self.image_index_embedding = nn.Parameter(torch.randn(5, hidden_size)) # support max 5 ref images self.gradient_checkpointing = False self.initialize_weights() def initialize_weights(self) -> None: """ Initialize the weights of the model. Uses Xavier uniform initialization for linear layers. """ nn.init.xavier_uniform_(self.x_embedder.weight) nn.init.constant_(self.x_embedder.bias, 0.0) nn.init.xavier_uniform_(self.ref_image_patch_embedder.weight) nn.init.constant_(self.ref_image_patch_embedder.bias, 0.0) nn.init.zeros_(self.norm_out.linear_1.weight) nn.init.zeros_(self.norm_out.linear_1.bias) nn.init.zeros_(self.norm_out.linear_2.weight) nn.init.zeros_(self.norm_out.linear_2.bias) nn.init.normal_(self.image_index_embedding, std=0.02) def img_patch_embed_and_refine( self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb ): batch_size = len(hidden_states) max_combined_img_len = max([img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)]) hidden_states = self.x_embedder(hidden_states) ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states) for i in range(batch_size): shift = 0 for j, ref_img_len in enumerate(l_effective_ref_img_len[i]): ref_image_hidden_states[i, shift:shift + ref_img_len, :] = ref_image_hidden_states[i, shift:shift + ref_img_len, :] + self.image_index_embedding[j] shift += ref_img_len for layer in self.noise_refiner: hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb) flat_l_effective_ref_img_len = list(itertools.chain(*l_effective_ref_img_len)) num_ref_images = len(flat_l_effective_ref_img_len) max_ref_img_len = max(flat_l_effective_ref_img_len) batch_ref_img_mask = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, dtype=torch.bool) batch_ref_image_hidden_states = ref_image_hidden_states.new_zeros(num_ref_images, max_ref_img_len, self.config.hidden_size) batch_ref_img_rotary_emb = hidden_states.new_zeros(num_ref_images, max_ref_img_len, ref_img_rotary_emb.shape[-1], dtype=ref_img_rotary_emb.dtype) batch_temb = temb.new_zeros(num_ref_images, *temb.shape[1:], dtype=temb.dtype) # sequence of ref imgs to batch idx = 0 for i in range(batch_size): shift = 0 for ref_img_len in l_effective_ref_img_len[i]: batch_ref_img_mask[idx, :ref_img_len] = True batch_ref_image_hidden_states[idx, :ref_img_len] = ref_image_hidden_states[i, shift:shift + ref_img_len] batch_ref_img_rotary_emb[idx, :ref_img_len] = ref_img_rotary_emb[i, shift:shift + ref_img_len] batch_temb[idx] = temb[i] shift += ref_img_len idx += 1 # refine ref imgs separately for layer in self.ref_image_refiner: batch_ref_image_hidden_states = layer(batch_ref_image_hidden_states, batch_ref_img_mask, batch_ref_img_rotary_emb, batch_temb) # batch of ref imgs to sequence idx = 0 for i in range(batch_size): shift = 0 for ref_img_len in l_effective_ref_img_len[i]: ref_image_hidden_states[i, shift:shift + ref_img_len] = batch_ref_image_hidden_states[idx, :ref_img_len] shift += ref_img_len idx += 1 combined_img_hidden_states = hidden_states.new_zeros(batch_size, max_combined_img_len, self.config.hidden_size) for i, (ref_img_len, img_len) in enumerate(zip(l_effective_ref_img_len, l_effective_img_len)): combined_img_hidden_states[i, :sum(ref_img_len)] = ref_image_hidden_states[i, :sum(ref_img_len)] combined_img_hidden_states[i, sum(ref_img_len):sum(ref_img_len) + img_len] = hidden_states[i, :img_len] return combined_img_hidden_states def flat_and_pad_to_seq(self, hidden_states, ref_image_hidden_states): batch_size = len(hidden_states) p = self.config.patch_size device = hidden_states[0].device img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] l_effective_img_len = [(H // p) * (W // p) for (H, W) in img_sizes] if ref_image_hidden_states is not None: ref_img_sizes = [[(img.size(1), img.size(2)) for img in imgs] if imgs is not None else None for imgs in ref_image_hidden_states] l_effective_ref_img_len = [[(ref_img_size[0] // p) * (ref_img_size[1] // p) for ref_img_size in _ref_img_sizes] if _ref_img_sizes is not None else [0] for _ref_img_sizes in ref_img_sizes] else: ref_img_sizes = [None for _ in range(batch_size)] l_effective_ref_img_len = [[0] for _ in range(batch_size)] max_ref_img_len = max([sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]) max_img_len = max(l_effective_img_len) # ref image patch embeddings flat_ref_img_hidden_states = [] for i in range(batch_size): if ref_img_sizes[i] is not None: imgs = [] for ref_img in ref_image_hidden_states[i]: C, H, W = ref_img.size() ref_img = rearrange(ref_img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p) imgs.append(ref_img) img = torch.cat(imgs, dim=0) flat_ref_img_hidden_states.append(img) else: flat_ref_img_hidden_states.append(None) # image patch embeddings flat_hidden_states = [] for i in range(batch_size): img = hidden_states[i] C, H, W = img.size() img = rearrange(img, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=p, p2=p) flat_hidden_states.append(img) padded_ref_img_hidden_states = torch.zeros(batch_size, max_ref_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) padded_ref_img_mask = torch.zeros(batch_size, max_ref_img_len, dtype=torch.bool, device=device) for i in range(batch_size): if ref_img_sizes[i] is not None: padded_ref_img_hidden_states[i, :sum(l_effective_ref_img_len[i])] = flat_ref_img_hidden_states[i] padded_ref_img_mask[i, :sum(l_effective_ref_img_len[i])] = True padded_hidden_states = torch.zeros(batch_size, max_img_len, flat_hidden_states[0].shape[-1], device=device, dtype=flat_hidden_states[0].dtype) padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) for i in range(batch_size): padded_hidden_states[i, :l_effective_img_len[i]] = flat_hidden_states[i] padded_img_mask[i, :l_effective_img_len[i]] = True return ( padded_hidden_states, padded_ref_img_hidden_states, padded_img_mask, padded_ref_img_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, ) def forward( self, hidden_states: Union[torch.Tensor, List[torch.Tensor]], timestep: torch.Tensor, text_hidden_states: torch.Tensor, freqs_cis: torch.Tensor, text_attention_mask: torch.Tensor, ref_image_hidden_states: Optional[List[List[torch.Tensor]]] = None, attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = False, ) -> Union[torch.Tensor, Transformer2DModelOutput]: if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() lora_scale = attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) # 1. Condition, positional & patch embedding batch_size = len(hidden_states) is_hidden_states_tensor = isinstance(hidden_states, torch.Tensor) if is_hidden_states_tensor: assert hidden_states.ndim == 4 hidden_states = [_hidden_states for _hidden_states in hidden_states] device = hidden_states[0].device temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype) ( hidden_states, ref_image_hidden_states, img_mask, ref_img_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states) ( context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb, rotary_emb, encoder_seq_lengths, seq_lengths, ) = self.rope_embedder( freqs_cis, text_attention_mask, l_effective_ref_img_len, l_effective_img_len, ref_img_sizes, img_sizes, device, ) # 2. Context refinement for layer in self.context_refiner: text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb) combined_img_hidden_states = self.img_patch_embed_and_refine( hidden_states, ref_image_hidden_states, img_mask, ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, ) # 3. Joint Transformer blocks max_seq_len = max(seq_lengths) attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)): attention_mask[i, :seq_len] = True joint_hidden_states[i, :encoder_seq_len] = text_hidden_states[i, :encoder_seq_len] joint_hidden_states[i, encoder_seq_len:seq_len] = combined_img_hidden_states[i, :seq_len - encoder_seq_len] hidden_states = joint_hidden_states for layer_idx, layer in enumerate(self.layers): if torch.is_grad_enabled() and self.gradient_checkpointing: hidden_states = self._gradient_checkpointing_func( layer, hidden_states, attention_mask, rotary_emb, temb ) else: hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb) # 4. Output norm & projection hidden_states = self.norm_out(hidden_states, temb) p = self.config.patch_size output = [] for i, (img_size, img_len, seq_len) in enumerate(zip(img_sizes, l_effective_img_len, seq_lengths)): height, width = img_size output.append(rearrange(hidden_states[i][seq_len - img_len:seq_len], '(h w) (p1 p2 c) -> c (h p1) (w p2)', h=height // p, w=width // p, p1=p, p2=p)) if is_hidden_states_tensor: output = torch.stack(output, dim=0) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: return output return Transformer2DModelOutput(sample=output)