import itertools
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.models.embeddings import Timesteps

from ..embeddings import TimestepEmbedding
from .components import swiglu
|
try:
    from ...ops.triton.layer_norm import RMSNorm as FusedRMSNorm

    FUSEDRMSNORM_AVAILABLE = True
except ImportError:
    FUSEDRMSNORM_AVAILABLE = False
    warnings.warn("Cannot import fused Triton RMSNorm, switch to vanilla implementation")

try:
    # Expected to match the vanilla `swiglu` contract: swiglu(x, y) == F.silu(x) * y.
    from flash_attn.ops.activations import swiglu as fused_swiglu

    FUSEDSWIGLU_AVAILABLE = True
except ImportError:
    FUSEDSWIGLU_AVAILABLE = False
    warnings.warn("Cannot import flash_attn fused SwiGLU, switch to vanilla implementation")
|
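# Note: the `use_fused_*` options below hard-assert that the corresponding fused kernel
# imported above is available. An illustrative, more defensive pattern for callers
# (hypothetical usage, not part of this module's API):
#
#     use_fused = FUSEDRMSNORM_AVAILABLE and torch.cuda.is_available()
#     norm = LuminaRMSNormZero(2048, norm_eps=1e-5, norm_elementwise_affine=True, use_fused_rms_norm=use_fused)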
|
|
class LuminaRMSNormZero(nn.Module):
    """
    Norm layer adaptive RMS normalization zero.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        norm_eps (`float`): The epsilon used by the RMS normalization.
        norm_elementwise_affine (`bool`): Whether the RMS normalization learns per-element scale parameters.
        use_fused_rms_norm (`bool`, *optional*, defaults to `False`):
            Whether to use the fused Triton RMSNorm kernel instead of `nn.RMSNorm`.
    """

    def __init__(
        self,
        embedding_dim: int,
        norm_eps: float,
        norm_elementwise_affine: bool,
        use_fused_rms_norm: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(
            min(embedding_dim, 1024),
            4 * embedding_dim,
            bias=True,
        )
        if use_fused_rms_norm:
            assert FUSEDRMSNORM_AVAILABLE
            self.norm = FusedRMSNorm(embedding_dim, eps=norm_eps)
        else:
            self.norm = nn.RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Project the conditioning embedding to per-sample scales and gates for the
        # attention (msa) and feed-forward (mlp) branches.
        emb = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
        x = self.norm(x) * (1 + scale_msa[:, None])
        return x, gate_msa, scale_mlp, gate_mlp
|
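# A minimal shape sketch for `LuminaRMSNormZero` (illustrative only; the conditioning
# embedding width follows `min(embedding_dim, 1024)` in `__init__`):
#
#     norm_zero = LuminaRMSNormZero(embedding_dim=2048, norm_eps=1e-5, norm_elementwise_affine=True)
#     x = torch.randn(2, 16, 2048)   # (batch, seq_len, dim) hidden states
#     emb = torch.randn(2, 1024)     # pooled conditioning embedding
#     x_mod, gate_msa, scale_mlp, gate_mlp = norm_zero(x, emb)
#     # x_mod: (2, 16, 2048); each gate/scale: (2, 2048)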
|
|
class LuminaLayerNormContinuous(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
        out_dim: Optional[int] = None,
        use_fused_rms_norm: bool = False,
    ):
        super().__init__()

        # AdaLN: the conditioning embedding is projected to a per-sample scale that
        # modulates the normalized hidden states.
        self.silu = nn.SiLU()
        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)

        if norm_type == "layer_norm":
            self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            if use_fused_rms_norm:
                assert FUSEDRMSNORM_AVAILABLE
                self.norm = FusedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
            else:
                self.norm = nn.RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

        self.linear_2 = None
        if out_dim is not None:
            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        conditioning_embedding: torch.Tensor,
    ) -> torch.Tensor:
        # Cast back to the hidden-state dtype in case the conditioning embedding was upcast.
        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
        scale = emb
        x = self.norm(x) * (1 + scale)[:, None, :]

        if self.linear_2 is not None:
            x = self.linear_2(x)

        return x
|
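# A minimal shape sketch for `LuminaLayerNormContinuous` (illustrative only):
#
#     final_norm = LuminaLayerNormContinuous(
#         embedding_dim=2048, conditioning_embedding_dim=1024, norm_type="rms_norm", out_dim=64
#     )
#     x = torch.randn(2, 16, 2048)   # (batch, seq_len, dim) hidden states
#     cond = torch.randn(2, 1024)    # conditioning embedding, e.g. the time embedding
#     out = final_norm(x, cond)      # (2, 16, 64) after the optional output projection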
|
|
class LuminaFeedForward(nn.Module):
    r"""
    A SwiGLU feed-forward layer.

    Parameters:
        dim (`int`):
            The dimensionality of the input and output of the layer. This parameter determines the width of the
            model's hidden representations.
        inner_dim (`int`): The intermediate dimension of the feed-forward layer.
        multiple_of (`int`, *optional*, defaults to 256): Value to ensure the inner dimension is rounded up to a
            multiple of this value.
        ffn_dim_multiplier (`float`, *optional*): Custom multiplier applied to `inner_dim` before rounding. Defaults
            to None.
        use_fused_swiglu (`bool`, *optional*, defaults to `False`): Whether to use the fused flash_attn SwiGLU kernel.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        multiple_of: Optional[int] = 256,
        ffn_dim_multiplier: Optional[float] = None,
        use_fused_swiglu: bool = False,
    ):
        super().__init__()
        self.use_fused_swiglu = use_fused_swiglu

        if use_fused_swiglu:
            assert FUSEDSWIGLU_AVAILABLE
            self.swiglu = fused_swiglu
        else:
            self.swiglu = swiglu

        # Optionally rescale the inner dimension, then round it up to a multiple of `multiple_of`.
        if ffn_dim_multiplier is not None:
            inner_dim = int(ffn_dim_multiplier * inner_dim)
        inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)

        self.linear_1 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )
        self.linear_2 = nn.Linear(
            inner_dim,
            dim,
            bias=False,
        )
        self.linear_3 = nn.Linear(
            dim,
            inner_dim,
            bias=False,
        )

    def forward(self, x):
        # SwiGLU: silu(W1 x) * (W3 x), projected back down by W2.
        h1, h2 = self.linear_1(x), self.linear_3(x)
        return self.linear_2(self.swiglu(h1, h2))
|
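# A minimal sketch for `LuminaFeedForward` (illustrative only). Note that `inner_dim`
# is rounded up to a multiple of `multiple_of`, so the effective width can exceed the
# requested value:
#
#     ff = LuminaFeedForward(dim=2048, inner_dim=5000)   # rounded up to 5120 (20 * 256)
#     x = torch.randn(2, 16, 2048)
#     out = ff(x)                                        # (2, 16, 2048)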
|
|
class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
    """Combined timestep and caption embedding for the Lumina 2 transformer."""

    def __init__(
        self,
        hidden_size: int = 4096,
        text_feat_dim: int = 2048,
        frequency_embedding_size: int = 256,
        norm_eps: float = 1e-5,
        timestep_scale: float = 1.0,
        use_fused_rms_norm: bool = False,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=timestep_scale
        )

        self.timestep_embedder = TimestepEmbedding(
            in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
        )

        if use_fused_rms_norm:
            assert FUSEDRMSNORM_AVAILABLE
            RMSNorm = FusedRMSNorm
        else:
            RMSNorm = nn.RMSNorm

        # Normalize the text features, then project them to the transformer width.
        self.caption_embedder = nn.Sequential(
            RMSNorm(text_feat_dim, eps=norm_eps),
            nn.Linear(text_feat_dim, hidden_size, bias=True),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        nn.init.trunc_normal_(self.caption_embedder[1].weight, std=0.02)
        nn.init.zeros_(self.caption_embedder[1].bias)

    def forward(
        self, timestep: torch.Tensor, text_hidden_states: torch.Tensor, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        timestep_proj = self.time_proj(timestep).to(dtype=dtype)
        time_embed = self.timestep_embedder(timestep_proj)
        caption_embed = self.caption_embedder(text_hidden_states)
        return time_embed, caption_embed
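

# A minimal sketch for `Lumina2CombinedTimestepCaptionEmbedding` (illustrative defaults,
# vanilla non-fused code paths):
#
#     embedder = Lumina2CombinedTimestepCaptionEmbedding(hidden_size=4096, text_feat_dim=2048)
#     timestep = torch.rand(2)                       # (batch,)
#     text_hidden_states = torch.randn(2, 77, 2048)  # (batch, text_seq_len, text_feat_dim)
#     time_embed, caption_embed = embedder(timestep, text_hidden_states, dtype=torch.float32)
#     # time_embed: (2, 1024), i.e. min(hidden_size, 1024); caption_embed: (2, 77, 4096)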