# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.embeddings import Timesteps

from ..embeddings import TimestepEmbedding
from .components import swiglu

try:
    # from apex.normalization import FusedRMSNorm
    # from flash_attn.ops.rms_norm import RMSNorm as FusedRMSNorm
    # from flash_attn.ops.triton.layer_norm import RMSNorm as FusedRMSNorm
    from ...ops.triton.layer_norm import RMSNorm as FusedRMSNorm

    FUSEDRMSNORM_AVALIBLE = True
except ImportError:
    FUSEDRMSNORM_AVALIBLE = False
    warnings.warn("Cannot import the fused Triton RMSNorm, falling back to the vanilla implementation")

try:
    from flash_attn.ops.activations import swiglu as fused_swiglu

    FUSEDSWIGLU_AVALIBLE = True
except ImportError:
    FUSEDSWIGLU_AVALIBLE = False
    warnings.warn("Cannot import the fused SwiGLU from flash_attn, falling back to the vanilla implementation")

class LuminaRMSNormZero(nn.Module):
    """
    Adaptive RMS normalization layer (adaLN-Zero style) that predicts scale and gate modulation
    parameters from a conditioning embedding.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_eps (`float`): The epsilon used by the RMS normalization.
        norm_elementwise_affine (`bool`): Kept for API compatibility; not forwarded to the norm layer here.
        use_fused_rms_norm (`bool`, *optional*, defaults to `False`): Whether to use the fused RMSNorm kernel.
    """

def __init__(
self,
embedding_dim: int,
norm_eps: float,
norm_elementwise_affine: bool,
use_fused_rms_norm: bool = False,
):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(
min(embedding_dim, 1024),
4 * embedding_dim,
bias=True,
)
if use_fused_rms_norm:
            assert FUSEDRMSNORM_AVALIBLE, "use_fused_rms_norm=True, but the fused RMSNorm kernel could not be imported"
self.norm = FusedRMSNorm(embedding_dim, eps=norm_eps)
else:
self.norm = nn.RMSNorm(embedding_dim, eps=norm_eps)

    def forward(
self,
x: torch.Tensor,
        emb: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
return x, gate_msa, scale_mlp, gate_mlp
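
# Usage sketch (illustrative only; the shapes below are made up for the example and do not
# come from any particular checkpoint):
#
#     norm = LuminaRMSNormZero(embedding_dim=256, norm_eps=1e-5, norm_elementwise_affine=True)
#     x = torch.randn(2, 16, 256)    # (batch, seq_len, embedding_dim)
#     emb = torch.randn(2, 256)      # (batch, min(embedding_dim, 1024))
#     x, gate_msa, scale_mlp, gate_mlp = norm(x, emb)
#
# `x` comes back RMS-normalized and scale-modulated; the gates and `scale_mlp` are meant to be
# consumed by the attention/MLP branches of the surrounding transformer block.
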
class LuminaLayerNormContinuous(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
out_dim: Optional[int] = None,
use_fused_rms_norm: bool = False
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
if use_fused_rms_norm:
                assert FUSEDRMSNORM_AVALIBLE, "use_fused_rms_norm=True, but the fused RMSNorm kernel could not be imported"
self.norm = FusedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
else:
self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
self.linear_2 = None
if out_dim is not None:
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

    def forward(
self,
x: torch.Tensor,
conditioning_embedding: torch.Tensor,
) -> torch.Tensor:
        # Convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for HunyuanDiT).
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
scale = emb
x = self.norm(x) * (1 + scale)[:, None, :]
if self.linear_2 is not None:
x = self.linear_2(x)
return x
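
# Usage sketch (illustrative): the conditioning embedding is projected to a per-channel scale
# that modulates the normalized hidden states; when `out_dim` is given, a final projection is
# applied (e.g. for the model's output head). Sizes below are placeholders.
#
#     norm_out = LuminaLayerNormContinuous(
#         embedding_dim=256,
#         conditioning_embedding_dim=256,
#         elementwise_affine=False,
#         norm_type="rms_norm",
#         out_dim=64,
#     )
#     x = torch.randn(2, 16, 256)     # (batch, seq_len, embedding_dim)
#     cond = torch.randn(2, 256)      # (batch, conditioning_embedding_dim)
#     out = norm_out(x, cond)         # (2, 16, 64)
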
class LuminaFeedForward(nn.Module):
    r"""
    A SwiGLU feed-forward layer.

    Parameters:
        dim (`int`):
            The dimensionality of the input and output of the layer.
        inner_dim (`int`):
            The requested intermediate dimension of the feed-forward layer; it is rounded up to a
            multiple of `multiple_of` (after applying `ffn_dim_multiplier`, if given).
        multiple_of (`int`, *optional*, defaults to 256):
            Value the inner dimension is rounded up to a multiple of.
        ffn_dim_multiplier (`float`, *optional*):
            Custom multiplier applied to `inner_dim` before rounding. Defaults to `None`.
        use_fused_swiglu (`bool`, *optional*, defaults to `False`):
            Whether to use the fused SwiGLU activation from flash_attn.
    """

def __init__(
self,
dim: int,
inner_dim: int,
multiple_of: Optional[int] = 256,
ffn_dim_multiplier: Optional[float] = None,
use_fused_swiglu: bool = False
):
super().__init__()
self.use_fused_swiglu = use_fused_swiglu
if use_fused_swiglu:
            assert FUSEDSWIGLU_AVALIBLE, "use_fused_swiglu=True, but the fused SwiGLU activation could not be imported from flash_attn"
self.swiglu = fused_swiglu
else:
self.swiglu = swiglu
        # custom multiplier for the inner (hidden) dimension
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
self.linear_1 = nn.Linear(
dim,
inner_dim,
bias=False,
)
self.linear_2 = nn.Linear(
inner_dim,
dim,
bias=False,
)
self.linear_3 = nn.Linear(
dim,
inner_dim,
bias=False,
)

    def forward(self, x):
h1, h2 = self.linear_1(x), self.linear_3(x)
return self.linear_2(self.swiglu(h1, h2))
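
# Usage sketch (illustrative). The forward pass computes a standard SwiGLU MLP,
# i.e. linear_2(swiglu(linear_1(x), linear_3(x))); with the vanilla `swiglu` this is presumably
# silu(linear_1(x)) * linear_3(x) fed into linear_2. `inner_dim` is rounded up to the next
# multiple of `multiple_of`.
#
#     ff = LuminaFeedForward(dim=256, inner_dim=688)   # inner_dim rounds up to 768
#     y = ff(torch.randn(2, 16, 256))                  # (2, 16, 256)
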
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
self,
hidden_size: int = 4096,
text_feat_dim: int = 2048,
frequency_embedding_size: int = 256,
norm_eps: float = 1e-5,
timestep_scale: float = 1.0,
use_fused_rms_norm: bool = False
) -> None:
super().__init__()
self.time_proj = Timesteps(
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
)
self.timestep_embedder = TimestepEmbedding(
in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
)
if use_fused_rms_norm:
            assert FUSEDRMSNORM_AVALIBLE, "use_fused_rms_norm=True, but the fused RMSNorm kernel could not be imported"
RMSNorm = FusedRMSNorm
else:
RMSNorm = nn.RMSNorm
self.caption_embedder = nn.Sequential(
RMSNorm(text_feat_dim, eps=norm_eps),
nn.Linear(text_feat_dim, hidden_size, bias=True),
)
self._initialize_weights()

    def _initialize_weights(self):
nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
nn.init.zeros_(self.caption_embedder[1].bias)

    def forward(
self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
timestep_proj = self.time_proj(timestep).to(dtype=dtype)
time_embed = self.timestep_embedder(timestep_proj)
caption_embed = self.caption_embedder(text_hidden_states)
return time_embed, caption_embed
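

# Minimal smoke test (illustrative only): exercises Lumina2CombinedTimestepCaptionEmbedding with
# small placeholder sizes (not the defaults of any released checkpoint) and the fused kernels
# disabled, so only PyTorch >= 2.4 (for nn.RMSNorm) and diffusers are required. Because of the
# relative imports above, run it as a module inside its package, e.g.
# `python -m <package_path>.<this_module>`.
if __name__ == "__main__":
    embed = Lumina2CombinedTimestepCaptionEmbedding(
        hidden_size=256, text_feat_dim=64, frequency_embedding_size=32
    )
    timestep = torch.rand(2)                       # random timesteps, shape (batch,)
    text_hidden_states = torch.randn(2, 16, 64)    # (batch, text_seq_len, text_feat_dim)
    time_embed, caption_embed = embed(timestep, text_hidden_states, dtype=torch.float32)
    print(time_embed.shape)     # torch.Size([2, 256]), i.e. min(hidden_size, 1024)
    print(caption_embed.shape)  # torch.Size([2, 16, 256]), projected to hidden_size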