# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn

from diffusers.models.embeddings import Timesteps

from ..embeddings import TimestepEmbedding
from .components import swiglu

try:
    # Alternative fused backends:
    # from apex.normalization import FusedRMSNorm
    # from flash_attn.ops.rms_norm import RMSNorm as FusedRMSNorm
    # from flash_attn.ops.triton.layer_norm import RMSNorm as FusedRMSNorm
    from ...ops.triton.layer_norm import RMSNorm as FusedRMSNorm

    FUSEDRMSNORM_AVALIBLE = True
except ImportError:
    FUSEDRMSNORM_AVALIBLE = False
    warnings.warn("Cannot import fused RMSNorm, switching to the vanilla implementation")

try:
    from flash_attn.ops.activations import swiglu as fused_swiglu

    FUSEDSWIGLU_AVALIBLE = True
except ImportError:
    FUSEDSWIGLU_AVALIBLE = False
    warnings.warn("Cannot import fused swiglu, switching to the vanilla implementation")


class LuminaRMSNormZero(nn.Module):
    """
    AdaLN-Zero style norm layer: RMS normalization whose scale and gates are predicted from a conditioning
    embedding.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_eps (`float`): The epsilon used by the RMS normalization.
        norm_elementwise_affine (`bool`): Whether the norm layer has learnable per-element affine parameters.
        use_fused_rms_norm (`bool`, *optional*, defaults to `False`): Whether to use the fused RMSNorm kernel.
    """

    def __init__(
        self,
        embedding_dim: int,
        norm_eps: float,
        norm_elementwise_affine: bool,
        use_fused_rms_norm: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(
            min(embedding_dim, 1024),
            4 * embedding_dim,
            bias=True,
        )
        if use_fused_rms_norm:
            assert FUSEDRMSNORM_AVALIBLE
            self.norm = FusedRMSNorm(embedding_dim, eps=norm_eps)
        else:
            self.norm = nn.RMSNorm(embedding_dim, eps=norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None])

        return x, gate_msa, scale_mlp, gate_mlp
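
# Illustrative sketch (not part of this module): one common way a transformer block
# consumes the outputs of `LuminaRMSNormZero`. The names `attn`, `ffn`, `norm1`,
# `norm2`, and `ffn_norm1` below are hypothetical placeholders for the surrounding
# block's layers, not symbols defined in this file.
#
#     norm_x, gate_msa, scale_mlp, gate_mlp = norm1(x, time_embed)
#     x = x + gate_msa.unsqueeze(1).tanh() * norm2(attn(norm_x))
#     x = x + gate_mlp.unsqueeze(1).tanh() * ffn(ffn_norm1(x) * (1 + scale_mlp.unsqueeze(1)))
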
class LuminaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        bias: bool = True,
        norm_type: str = "layer_norm",
        out_dim: Optional[int] = None,
        use_fused_rms_norm: bool = False,
    ):
        super().__init__()

        # AdaLN
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)

        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            if use_fused_rms_norm:
                assert FUSEDRMSNORM_AVALIBLE
                self.norm = FusedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
            else:
                self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

        self.linear_2 = None
        if out_dim is not None:
            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        scale = emb
        x = self.norm(x) * (1 + scale)[:, None, :]

        if self.linear_2 is not None:
            x = self.linear_2(x)

        return x


class LuminaFeedForward(nn.Module):
    r"""
    A SwiGLU feed-forward layer.

    Parameters:
        dim (`int`):
            The dimensionality of the model's hidden representations; the input and output width of the layer.
        inner_dim (`int`):
            The intermediate dimension of the feed-forward layer before rounding.
        multiple_of (`int`, *optional*, defaults to 256):
            Value the intermediate dimension is rounded up to a multiple of.
        ffn_dim_multiplier (`float`, *optional*):
            Custom multiplier applied to `inner_dim`. Defaults to None.
        use_fused_swiglu (`bool`, *optional*, defaults to `False`):
            Whether to use the fused swiglu kernel from flash_attn.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
        use_fused_swiglu: bool = False,
    ):
        super().__init__()

        self.use_fused_swiglu = use_fused_swiglu
        if use_fused_swiglu:
            assert FUSEDSWIGLU_AVALIBLE
            self.swiglu = fused_swiglu
        else:
            self.swiglu = swiglu

        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )

    def forward(self, x):
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(self.swiglu(h1, h2))
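
# Reference sketch of the SwiGLU activation assumed above. The vanilla `swiglu`
# imported from `.components` and flash_attn's fused kernel are both expected to
# compute the standard SiLU-gated product; the equivalent below is shown for
# clarity only and is not used by this module:
#
#     def _swiglu_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         return torch.nn.functional.silu(x) * y  # silu(W1 x) * (W3 x)
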
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    def __init__(
        self,
        hidden_size: int = 4096,
        text_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
        timestep_scale: float = 1.0,
        use_fused_rms_norm: bool = False,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size,
            flip_sin_to_cos=True,
            downscale_freq_shift=0.0,
            scale=timestep_scale,
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size,
            time_embed_dim=min(hidden_size, 1024),
        )

        if use_fused_rms_norm:
            assert FUSEDRMSNORM_AVALIBLE
            RMSNorm = FusedRMSNorm
        else:
            RMSNorm = nn.RMSNorm

        self.caption_embedder = nn.Sequential(
            RMSNorm(text_feat_dim, eps=norm_eps),
            nn.Linear(text_feat_dim, hidden_size, bias=True),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
        nn.init.zeros_(self.caption_embedder[1].bias)

    def forward(
        self,
        timestep: torch.Tensor,
        text_hidden_states: torch.Tensor,
        dtype: torch.dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed
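

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only). Shapes are arbitrary, the fused
    # kernels are left disabled so it runs on CPU-only installs, and because this
    # module uses relative imports it has to be run as `python -m <package path>`.
    batch, seq_len, dim, text_dim = 2, 16, 256, 128

    norm_zero = LuminaRMSNormZero(dim, norm_eps=1e-5, norm_elementwise_affine=True)
    ffn = LuminaFeedForward(dim=dim, inner_dim=4 * dim)
    embedder = Lumina2CombinedTimestepCaptionEmbedding(hidden_size=dim, text_feat_dim=text_dim)

    x = torch.randn(batch, seq_len, dim)
    text_hidden_states = torch.randn(batch, 8, text_dim)
    timestep = torch.rand(batch)

    # Conditioning embeddings from the timestep and caption features.
    time_embed, caption_embed = embedder(timestep, text_hidden_states, dtype=x.dtype)
    # Adaptive norm returns the modulated hidden states plus the MLP scale and gates.
    x_norm, gate_msa, scale_mlp, gate_mlp = norm_zero(x, time_embed)
    out = ffn(x_norm)
    print(out.shape, time_embed.shape, caption_embed.shape)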