# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import itertools
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.embeddings import Timesteps

from ..embeddings import TimestepEmbedding
from .components import swiglu
try:
    # from apex.normalization import FusedRMSNorm
    # from flash_attn.ops.rms_norm import RMSNorm as FusedRMSNorm
    # from flash_attn.ops.triton.layer_norm import RMSNorm as FusedRMSNorm
    from ...ops.triton.layer_norm import RMSNorm as FusedRMSNorm

    FUSEDRMSNORM_AVALIBLE = True
except ImportError:
    FUSEDRMSNORM_AVALIBLE = False
    warnings.warn("Cannot import fused RMSNorm, falling back to the vanilla implementation")

try:
    from flash_attn.ops.activations import swiglu as fused_swiglu

    FUSEDSWIGLU_AVALIBLE = True
except ImportError:
    FUSEDSWIGLU_AVALIBLE = False
    warnings.warn("Cannot import fused SwiGLU from flash_attn, falling back to the vanilla implementation")
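# Illustrative sketch (not part of the original module): the vanilla `swiglu` imported from
# `.components` is assumed to match the fused flash_attn kernel, i.e. a SiLU-gated
# multiplication of the two projections.
def _reference_swiglu(x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
    # SwiGLU gating: silu(x) * gate
    return F.silu(x) * gate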
class LuminaRMSNormZero(nn.Module):
    """
    Adaptive RMS normalization layer (AdaLN-Zero style).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_eps (`float`): The epsilon used by the RMS normalization.
        norm_elementwise_affine (`bool`): Whether the norm uses learnable elementwise affine parameters.
        use_fused_rms_norm (`bool`, *optional*, defaults to `False`): Whether to use the fused RMSNorm kernel.
    """

    def __init__(
        self,
        embedding_dim: int,
        norm_eps: float,
        norm_elementwise_affine: bool,
        use_fused_rms_norm: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(
            min(embedding_dim, 1024),
            4 * embedding_dim,
            bias=True,
        )
        if use_fused_rms_norm:
            assert FUSEDRMSNORM_AVALIBLE
            self.norm = FusedRMSNorm(embedding_dim, eps=norm_eps)
        else:
            # NOTE: `norm_elementwise_affine` is accepted for API compatibility but is not
            # forwarded to the norm layer here.
            self.norm = nn.RMSNorm(embedding_dim, eps=norm_eps)
    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None])
        return x, gate_msa, scale_mlp, gate_mlp
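# Illustrative usage sketch (hypothetical sizes, not part of the original module): `x` is a
# token sequence of shape (batch, seq_len, embedding_dim) and `emb` a per-sample conditioning
# vector of width min(embedding_dim, 1024); the layer returns the modulated hidden states
# together with the attention/MLP gates and the MLP scale.
def _example_rms_norm_zero() -> None:
    layer = LuminaRMSNormZero(embedding_dim=512, norm_eps=1e-5, norm_elementwise_affine=True)
    x = torch.randn(2, 16, 512)
    emb = torch.randn(2, 512)  # min(512, 1024) == 512
    x_out, gate_msa, scale_mlp, gate_mlp = layer(x, emb)
    assert x_out.shape == (2, 16, 512)
    assert gate_msa.shape == scale_mlp.shape == gate_mlp.shape == (2, 512)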
class LuminaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
        # However, this is how it was implemented in the original code, and it's rather likely you should
        # set `elementwise_affine` to False.
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
        out_dim: Optional[int] = None,
        use_fused_rms_norm: bool = False,
    ):
        super().__init__()
        # AdaLN
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            if use_fused_rms_norm:
                assert FUSEDRMSNORM_AVALIBLE
                self.norm = FusedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
            else:
                self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")
        self.linear_2 = None
        if out_dim is not None:
            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        # convert back to the original dtype in case `conditioning_embedding` is upcast to float32 (needed for hunyuanDiT)
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        scale = emb
        x = self.norm(x) * (1 + scale)[:, None, :]
        if self.linear_2 is not None:
            x = self.linear_2(x)
        return x
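# Illustrative usage sketch (hypothetical sizes, not part of the original module): the layer
# normalizes `x`, scales it with a projection of the conditioning embedding, and optionally
# projects to `out_dim` when `linear_2` is configured.
def _example_layer_norm_continuous() -> None:
    layer = LuminaLayerNormContinuous(
        embedding_dim=512, conditioning_embedding_dim=256, norm_type="rms_norm", out_dim=64
    )
    x = torch.randn(2, 16, 512)
    cond = torch.randn(2, 256)
    out = layer(x, cond)
    assert out.shape == (2, 16, 64)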
class LuminaFeedForward(nn.Module):
    r"""
    A SwiGLU feed-forward layer.

    Parameters:
        dim (`int`):
            The dimensionality of the input and output hidden states.
        inner_dim (`int`):
            The intermediate dimension of the feed-forward layer before rounding.
        multiple_of (`int`, *optional*, defaults to 256):
            The hidden dimension is rounded up to a multiple of this value.
        ffn_dim_multiplier (`float`, *optional*):
            Custom multiplier applied to the hidden dimension. Defaults to None.
        use_fused_swiglu (`bool`, *optional*, defaults to `False`):
            Whether to use the fused SwiGLU kernel from flash_attn.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
        use_fused_swiglu: bool = False,
    ):
        super().__init__()
        self.use_fused_swiglu = use_fused_swiglu
        if use_fused_swiglu:
            assert FUSEDSWIGLU_AVALIBLE
            self.swiglu = fused_swiglu
        else:
            self.swiglu = swiglu

        # custom hidden_size factor multiplier
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )

    def forward(self, x):
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(self.swiglu(h1, h2))
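# Illustrative sketch (hypothetical sizes, not part of the original module): with dim=512,
# inner_dim=1365 and multiple_of=256, the hidden width is rounded up to 1536 before the
# three projections are created; the forward pass is a SwiGLU-gated MLP.
def _example_feed_forward() -> None:
    ff = LuminaFeedForward(dim=512, inner_dim=1365, multiple_of=256)
    assert ff.linear_1.out_features == 1536  # 256 * ceil(1365 / 256)
    x = torch.randn(2, 16, 512)
    out = ff(x)
    assert out.shape == (2, 16, 512)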
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    """Embeds diffusion timesteps and text (caption) hidden states for Lumina 2."""

    def __init__(
        self,
        hidden_size: int = 4096,
        text_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
        timestep_scale: float = 1.0,
        use_fused_rms_norm: bool = False,
    ) -> None:
        super().__init__()
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )
        if use_fused_rms_norm:
            assert FUSEDRMSNORM_AVALIBLE
            RMSNorm = FusedRMSNorm
        else:
            RMSNorm = nn.RMSNorm
        self.caption_embedder = nn.Sequential(
            RMSNorm(text_feat_dim, eps=norm_eps),
            nn.Linear(text_feat_dim, hidden_size, bias=True),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
        nn.init.zeros_(self.caption_embedder[1].bias)

    def forward(
        self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed
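# Illustrative usage sketch (hypothetical sizes, not part of the original module): the embedder
# maps a batch of scalar timesteps to a time embedding of width min(hidden_size, 1024) and the
# text hidden states to caption embeddings of width hidden_size.
def _example_combined_embedding() -> None:
    embedder = Lumina2CombinedTimestepCaptionEmbedding(hidden_size=2048, text_feat_dim=1024)
    timestep = torch.rand(2)
    text_hidden_states = torch.randn(2, 77, 1024)
    time_embed, caption_embed = embedder(timestep, text_hidden_states, dtype=torch.float32)
    assert time_embed.shape == (2, 1024)
    assert caption_embed.shape == (2, 77, 2048)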