# some modules/classes are copied and modified from https://github.com/mcmonkey4eva/sd3-ref
# the original code is licensed under the MIT License
# and some modules/classes are contributed by KohakuBlueleaf. Thanks for the contribution!

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from functools import partial
import math
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import CLIPTokenizer, T5TokenizerFast

from library import custom_offloading_utils
from library.device_utils import clean_memory_on_device

from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

memory_efficient_attention = None
try:
    import xformers
except ImportError:
    pass

try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    memory_efficient_attention = None

# region mmdit


@dataclass
class SD3Params:
    patch_size: int
    depth: int
    num_patches: int
    pos_embed_max_size: int
    adm_in_channels: int
    qk_norm: Optional[str]
    x_block_self_attn_layers: list[int]
    context_embedder_in_features: int
    context_embedder_out_features: int
    model_type: str

def get_2d_sincos_pos_embed(
    embed_dim,
    grid_size,
    scaling_factor=None,
    offset=None,
):
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)
    if scaling_factor is not None:
        grid = grid / scaling_factor
    if offset is not None:
        grid = grid - offset

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb
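
# Example (sketch): a 32x32 position grid with embed_dim 1536 yields a (1024, 1536) table.
# This is the unscaled variant that initialize_weights() below copies into the fixed
# pos_embed buffer; the sizes here are illustrative, not taken from a checkpoint.
# if __name__ == "__main__":
#     pos_emb = get_2d_sincos_pos_embed(1536, 32)
#     print(pos_emb.shape)  # (1024, 1536)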

def get_scaled_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, sample_size=64, base_size=16):
    """
    This function is contributed by KohakuBlueleaf. Thanks for the contribution!

    Creates scaled 2D sinusoidal positional embeddings that maintain consistent relative positions
    when the resolution differs from the training resolution.

    Args:
        embed_dim (int): Dimension of the positional embedding.
        grid_size (int or tuple): Size of the position grid (H, W). If int, assumes square grid.
        cls_token (bool): Whether to include class token. Defaults to False.
        extra_tokens (int): Number of extra tokens (e.g., cls_token). Defaults to 0.
        sample_size (int): Reference resolution (typically training resolution). Defaults to 64.
        base_size (int): Base grid size used during training. Defaults to 16.

    Returns:
        numpy.ndarray: Positional embeddings of shape (H*W, embed_dim) or
            (H*W + extra_tokens, embed_dim) if cls_token is True.
    """
    # Convert grid_size to tuple if it's an integer
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    # Create normalized grid coordinates (0 to 1)
    grid_h = np.arange(grid_size[0], dtype=np.float32) / grid_size[0]
    grid_w = np.arange(grid_size[1], dtype=np.float32) / grid_size[1]

    # Calculate scaling factors for height and width
    # This ensures that the central region matches the original resolution's embeddings
    scale_h = base_size * grid_size[0] / sample_size
    scale_w = base_size * grid_size[1] / sample_size

    # Calculate shift values to center the original resolution's embedding region
    # This ensures that the central sample_size x sample_size region has similar
    # positional embeddings to the original resolution
    shift_h = 1 * scale_h * (grid_size[0] - sample_size) / (2 * grid_size[0])
    shift_w = 1 * scale_w * (grid_size[1] - sample_size) / (2 * grid_size[1])

    # Apply scaling and shifting to create the final grid coordinates
    grid_h = grid_h * scale_h - shift_h
    grid_w = grid_w * scale_w - shift_w

    # Create 2D grid using meshgrid (note: w goes first)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)

    # # Calculate the starting indices for the central region
    # # This is used for debugging/visualization of the central region
    # st_h = (grid_size[0] - sample_size) // 2
    # st_w = (grid_size[1] - sample_size) // 2
    # print(grid[:, st_h : st_h + sample_size, st_w : st_w + sample_size])

    # Reshape grid for positional embedding calculation
    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])

    # Generate the sinusoidal positional embeddings
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

    # Add zeros for extra tokens (e.g., [CLS] token) if required
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)

    return pos_embed

# if __name__ == "__main__":
#     # This is what you get when you load SD3.5 state dict
#     pos_emb = torch.from_numpy(get_scaled_2d_sincos_pos_embed(
#         1536, [384, 384], sample_size=64, base_size=16
#     )).float().unsqueeze(0)

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

def get_1d_sincos_pos_embed_from_grid_torch(
    embed_dim,
    pos,
    device=None,
    dtype=torch.float32,
):
    omega = torch.arange(embed_dim // 2, device=device, dtype=dtype)
    omega *= 2.0 / embed_dim
    omega = 1.0 / 10000**omega
    out = torch.outer(pos.reshape(-1), omega)
    emb = torch.cat([out.sin(), out.cos()], dim=1)
    return emb

def get_2d_sincos_pos_embed_torch(
    embed_dim,
    w,
    h,
    val_center=7.5,
    val_magnitude=7.5,
    device=None,
    dtype=torch.float32,
):
    small = min(h, w)
    val_h = (h / small) * val_magnitude
    val_w = (w / small) * val_magnitude
    grid_h, grid_w = torch.meshgrid(
        torch.linspace(-val_h + val_center, val_h + val_center, h, device=device, dtype=dtype),
        torch.linspace(-val_w + val_center, val_w + val_center, w, device=device, dtype=dtype),
        indexing="ij",
    )
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
    emb = torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
    return emb
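
# Example (sketch): the torch variant builds the table directly on a device. For a
# hypothetical 64x64 patched latent with embed_dim 1536 the result is (4096, 1536).
# if __name__ == "__main__":
#     emb = get_2d_sincos_pos_embed_torch(1536, w=64, h=64, device="cpu")
#     print(emb.shape)  # torch.Size([4096, 1536])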

def modulate(x, shift, scale):
    if shift is None:
        shift = torch.zeros_like(scale)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def default(x, default_value):
    if x is None:
        return default_value
    return x


def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
    #     device=t.device, dtype=t.dtype
    # )
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(dtype=t.dtype)
    return embedding
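
# Example (sketch): a batch of two timesteps mapped to 256-dim sinusoidal features;
# TimestepEmbedding below projects this frequency embedding to the model hidden size.
# if __name__ == "__main__":
#     emb = timestep_embedding(torch.tensor([0.0, 999.0]), 256)
#     print(emb.shape)  # torch.Size([2, 256])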

class PatchEmbed(nn.Module):
    def __init__(
        self,
        img_size=256,
        patch_size=4,
        in_channels=3,
        embed_dim=512,
        norm_layer=None,
        flatten=True,
        bias=True,
        strict_img_size=True,
        dynamic_img_pad=False,
    ):
        # dynamic_img_pad and norm is omitted in SD3.5
        super().__init__()
        self.patch_size = patch_size
        self.flatten = flatten
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad
        if img_size is not None:
            self.img_size = img_size
            self.grid_size = img_size // patch_size
            self.num_patches = self.grid_size**2
        else:
            self.img_size = None
            self.grid_size = None
            self.num_patches = None

        self.proj = nn.Conv2d(in_channels, embed_dim, patch_size, patch_size, bias=bias)
        self.norm = nn.Identity() if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape

        if self.dynamic_img_pad:
            # Pad input so we won't have partial patch
            pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
            pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="reflect")
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x
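
# Example (sketch): with illustrative SD3-like sizes, a 64x64, 16-channel latent patched
# with patch_size 2 becomes a sequence of 1024 tokens of width embed_dim.
# if __name__ == "__main__":
#     embedder = PatchEmbed(img_size=None, patch_size=2, in_channels=16, embed_dim=1536)
#     print(embedder(torch.randn(1, 16, 64, 64)).shape)  # torch.Size([1, 1024, 1536])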

# FinalLayer in mmdit.py
class UnPatch(nn.Module):
    def __init__(self, hidden_size=512, patch_size=4, out_channels=3):
        super().__init__()
        self.patch_size = patch_size
        self.c = out_channels

        # eps is default in mmdit.py
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size**2 * out_channels)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size),
        )

    def forward(self, x: torch.Tensor, cmod, H=None, W=None):
        b, n, _ = x.shape
        p = self.patch_size
        c = self.c
        if H is None and W is None:
            w = h = int(n**0.5)
            assert h * w == n
        else:
            h = H // p if H else n // (W // p)
            w = W // p if W else n // h
            assert h * w == n

        shift, scale = self.adaLN_modulation(cmod).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)

        x = x.view(b, h, w, p, p, c)
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.view(b, c, h * p, w * p)
        return x
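
# Example (sketch): UnPatch folds a (B, N, p*p*C) token sequence back into an image-like
# latent; the sizes below are illustrative.
# if __name__ == "__main__":
#     unpatch = UnPatch(hidden_size=512, patch_size=2, out_channels=16)
#     tokens = torch.randn(1, 32 * 32, 512)  # 32x32 patches
#     cmod = torch.randn(1, 512)  # conditioning vector
#     print(unpatch(tokens, cmod, H=64, W=64).shape)  # torch.Size([1, 16, 64, 64])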

class MLP(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=lambda: nn.GELU(),
        norm_layer=None,
        bias=True,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.use_conv = use_conv

        layer = partial(nn.Conv1d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = layer(in_features, hidden_features, bias=bias)
        self.fc2 = layer(hidden_features, out_features, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.norm(x)
        x = self.fc2(x)
        return x

class TimestepEmbedding(nn.Module):
    def __init__(self, hidden_size, freq_embed_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(freq_embed_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.freq_embed_size = freq_embed_size

    def forward(self, t, dtype=None, **kwargs):
        t_freq = timestep_embedding(t, self.freq_embed_size).to(dtype)
        t_emb = self.mlp(t_freq)
        return t_emb


class Embedder(nn.Module):
    def __init__(self, input_dim, hidden_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

def rmsnorm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


class RMSNorm(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine: bool = False,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        super().__init__()
        self.eps = eps
        self.learnable_scale = elementwise_affine
        if self.learnable_scale:
            self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
        else:
            self.register_parameter("weight", None)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        x = rmsnorm(x, eps=self.eps)
        if self.learnable_scale:
            return x * self.weight.to(device=x.device, dtype=x.dtype)
        else:
            return x

class SwiGLUFeedForward(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: float = None,
    ):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)
        # custom dim factor multiplier
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
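
# Example (sketch): with dim=1536 and an mlp_ratio of 4, the requested hidden_dim is 6144;
# the SwiGLU variant shrinks it to 2/3 (4096) and rounds up to a multiple of 256 (still 4096).
# if __name__ == "__main__":
#     ff = SwiGLUFeedForward(dim=1536, hidden_dim=4 * 1536, multiple_of=256)
#     print(ff.w1.out_features)  # 4096
#     print(ff(torch.randn(1, 10, 1536)).shape)  # torch.Size([1, 10, 1536])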

# Linears for SelfAttention in mmdit.py
class AttentionLinears(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        pre_only: bool = False,
        qk_norm: Optional[str] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        if not pre_only:
            self.proj = nn.Linear(dim, dim)
        self.pre_only = pre_only

        if qk_norm == "rms":
            self.ln_q = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
            self.ln_k = RMSNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
        elif qk_norm == "ln":
            self.ln_q = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
            self.ln_k = nn.LayerNorm(self.head_dim, elementwise_affine=True, eps=1.0e-6)
        elif qk_norm is None:
            self.ln_q = nn.Identity()
            self.ln_k = nn.Identity()
        else:
            raise ValueError(qk_norm)

    def pre_attention(self, x: torch.Tensor) -> torch.Tensor:
        """
        output:
            q, k, v: [B, L, D]
        """
        B, L, C = x.shape
        qkv: torch.Tensor = self.qkv(x)
        q, k, v = qkv.reshape(B, L, -1, self.head_dim).chunk(3, dim=2)
        q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1)
        k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1)
        return (q, k, v)

    def post_attention(self, x: torch.Tensor) -> torch.Tensor:
        assert not self.pre_only
        x = self.proj(x)
        return x

MEMORY_LAYOUTS = {
    "torch": (
        lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
        lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
        lambda x: (1, x, 1, 1),
    ),
    "xformers": (
        lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim),
        lambda x: x.reshape(x.shape[0], x.shape[1], -1),
        lambda x: (1, 1, x, 1),
    ),
    "math": (
        lambda x, head_dim: x.reshape(x.shape[0], x.shape[1], -1, head_dim).transpose(1, 2),
        lambda x: x.transpose(1, 2).reshape(x.shape[0], x.shape[2], -1),
        lambda x: (1, x, 1, 1),
    ),
}
# ATTN_FUNCTION = {
#     "torch": F.scaled_dot_product_attention,
#     "xformers": memory_efficient_attention,
# }

def vanilla_attention(q, k, v, mask, scale=None):
    if scale is None:
        scale = math.sqrt(q.size(-1))
    scores = torch.bmm(q, k.transpose(-1, -2)) / scale
    if mask is not None:
        mask = einops.rearrange(mask, "b ... -> b (...)")
        max_neg_value = -torch.finfo(scores.dtype).max
        mask = einops.repeat(mask, "b j -> (b h) j", h=q.size(-3))
        scores = scores.masked_fill(~mask, max_neg_value)
    p_attn = F.softmax(scores, dim=-1)
    return torch.bmm(p_attn, v)


def attention(q, k, v, head_dim, mask=None, scale=None, mode="xformers"):
    """
    q, k, v: [B, L, D]
    """
    pre_attn_layout = MEMORY_LAYOUTS[mode][0]
    post_attn_layout = MEMORY_LAYOUTS[mode][1]
    q = pre_attn_layout(q, head_dim)
    k = pre_attn_layout(k, head_dim)
    v = pre_attn_layout(v, head_dim)

    # scores = ATTN_FUNCTION[mode](q, k.to(q), v.to(q), mask, scale=scale)
    if mode == "torch":
        assert scale is None
        scores = F.scaled_dot_product_attention(q, k.to(q), v.to(q), mask)  # , scale=scale)
    elif mode == "xformers":
        scores = memory_efficient_attention(q, k.to(q), v.to(q), mask, scale=scale)
    else:
        scores = vanilla_attention(q, k.to(q), v.to(q), mask, scale=scale)

    scores = post_attn_layout(scores)
    return scores
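
# Example (sketch): joint attention over flattened [B, L, D] tensors using the PyTorch
# SDPA path; a head_dim of 64 matches the hidden_size = 64 * depth convention used below.
# if __name__ == "__main__":
#     q = k = v = torch.randn(2, 154 + 1024, 24 * 64)  # context tokens + image tokens
#     out = attention(q, k, v, head_dim=64, mode="torch")
#     print(out.shape)  # torch.Size([2, 1178, 1536])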

# DismantledBlock in mmdit.py
class SingleDiTBlock(nn.Module):
    """
    A DiT block with gated adaptive layer norm (adaLN) conditioning.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: str = "xformers",
        qkv_bias: bool = False,
        pre_only: bool = False,
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        qk_norm: Optional[str] = None,
        x_block_self_attn: bool = False,
        **block_kwargs,
    ):
        super().__init__()
        assert attn_mode in MEMORY_LAYOUTS
        self.attn_mode = attn_mode
        if not rmsnorm:
            self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        else:
            self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=pre_only, qk_norm=qk_norm)

        self.x_block_self_attn = x_block_self_attn
        if self.x_block_self_attn:
            assert not pre_only
            assert not scale_mod_only
            self.attn2 = AttentionLinears(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, pre_only=False, qk_norm=qk_norm)

        if not pre_only:
            if not rmsnorm:
                self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
            else:
                self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        if not pre_only:
            if not swiglu:
                self.mlp = MLP(
                    in_features=hidden_size,
                    hidden_features=mlp_hidden_dim,
                    act_layer=lambda: nn.GELU(approximate="tanh"),
                )
            else:
                self.mlp = SwiGLUFeedForward(
                    dim=hidden_size,
                    hidden_dim=mlp_hidden_dim,
                    multiple_of=256,
                )

        self.scale_mod_only = scale_mod_only
        if self.x_block_self_attn:
            n_mods = 9
        elif not scale_mod_only:
            n_mods = 6 if not pre_only else 2
        else:
            n_mods = 4 if not pre_only else 1
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, n_mods * hidden_size))
        self.pre_only = pre_only

    def pre_attention(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        if not self.pre_only:
            if not self.scale_mod_only:
                (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=-1)
            else:
                shift_msa = None
                shift_mlp = None
                (scale_msa, gate_msa, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(4, dim=-1)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp)
        else:
            if not self.scale_mod_only:
                (shift_msa, scale_msa) = self.adaLN_modulation(c).chunk(2, dim=-1)
            else:
                shift_msa = None
                scale_msa = self.adaLN_modulation(c)
            qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa))
            return qkv, None

    def pre_attention_x(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        assert self.x_block_self_attn
        (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2) = self.adaLN_modulation(
            c
        ).chunk(9, dim=1)
        x_norm = self.norm1(x)
        qkv = self.attn.pre_attention(modulate(x_norm, shift_msa, scale_msa))
        qkv2 = self.attn2.pre_attention(modulate(x_norm, shift_msa2, scale_msa2))
        return qkv, qkv2, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2)

    def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp):
        assert not self.pre_only
        x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

    def post_attention_x(self, attn, attn2, x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2, attn1_dropout: float = 0.0):
        assert not self.pre_only
        if attn1_dropout > 0.0:
            # Use torch.bernoulli to implement dropout, only dropout the batch dimension
            attn1_dropout = torch.bernoulli(torch.full((attn.size(0), 1, 1), 1 - attn1_dropout, device=attn.device))
            attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn) * attn1_dropout
        else:
            attn_ = gate_msa.unsqueeze(1) * self.attn.post_attention(attn)
        x = x + attn_
        attn2_ = gate_msa2.unsqueeze(1) * self.attn2.post_attention(attn2)
        x = x + attn2_
        mlp_ = gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        x = x + mlp_
        return x

# JointBlock + block_mixing in mmdit.py
class MMDiTBlock(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        pre_only = kwargs.pop("pre_only")
        x_block_self_attn = kwargs.pop("x_block_self_attn")
        self.context_block = SingleDiTBlock(*args, pre_only=pre_only, **kwargs)
        self.x_block = SingleDiTBlock(*args, pre_only=False, x_block_self_attn=x_block_self_attn, **kwargs)
        self.head_dim = self.x_block.attn.head_dim
        self.mode = self.x_block.attn_mode
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        # counterpart of enable_gradient_checkpointing; called by MMDiT.disable_gradient_checkpointing
        self.gradient_checkpointing = False

    def _forward(self, context, x, c):
        ctx_qkv, ctx_intermediate = self.context_block.pre_attention(context, c)

        if self.x_block.x_block_self_attn:
            x_qkv, x_qkv2, x_intermediates = self.x_block.pre_attention_x(x, c)
        else:
            x_qkv, x_intermediates = self.x_block.pre_attention(x, c)

        ctx_len = ctx_qkv[0].size(1)

        q = torch.concat((ctx_qkv[0], x_qkv[0]), dim=1)
        k = torch.concat((ctx_qkv[1], x_qkv[1]), dim=1)
        v = torch.concat((ctx_qkv[2], x_qkv[2]), dim=1)

        attn = attention(q, k, v, head_dim=self.head_dim, mode=self.mode)
        ctx_attn_out = attn[:, :ctx_len]
        x_attn_out = attn[:, ctx_len:]

        if self.x_block.x_block_self_attn:
            x_q2, x_k2, x_v2 = x_qkv2
            attn2 = attention(x_q2, x_k2, x_v2, self.x_block.attn2.num_heads, mode=self.mode)
            x = self.x_block.post_attention_x(x_attn_out, attn2, *x_intermediates)
        else:
            x = self.x_block.post_attention(x_attn_out, *x_intermediates)

        if not self.context_block.pre_only:
            context = self.context_block.post_attention(ctx_attn_out, *ctx_intermediate)
        else:
            context = None

        return context, x

    def forward(self, *args, **kwargs):
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)

class MMDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """

    # prepare pos_embed for latent size * 2
    POS_EMBED_MAX_RATIO = 1.5

    def __init__(
        self,
        input_size: int = 32,
        patch_size: int = 2,
        in_channels: int = 4,
        depth: int = 28,
        # hidden_size: Optional[int] = None,
        # num_heads: Optional[int] = None,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = False,
        adm_in_channels: Optional[int] = None,
        context_embedder_in_features: Optional[int] = None,
        context_embedder_out_features: Optional[int] = None,
        use_checkpoint: bool = False,
        register_length: int = 0,
        attn_mode: str = "torch",
        rmsnorm: bool = False,
        scale_mod_only: bool = False,
        swiglu: bool = False,
        out_channels: Optional[int] = None,
        pos_embed_scaling_factor: Optional[float] = None,
        pos_embed_offset: Optional[float] = None,
        pos_embed_max_size: Optional[int] = None,
        num_patches=None,
        qk_norm: Optional[str] = None,
        x_block_self_attn_layers: Optional[list[int]] = [],
        qkv_bias: bool = True,
        pos_emb_random_crop_rate: float = 0.0,
        use_scaled_pos_embed: bool = False,
        pos_embed_latent_sizes: Optional[list[int]] = None,
        model_type: str = "sd3m",
    ):
        super().__init__()
        self._model_type = model_type
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        default_out_channels = in_channels * 2 if learn_sigma else in_channels
        self.out_channels = default(out_channels, default_out_channels)
        self.patch_size = patch_size
        self.pos_embed_scaling_factor = pos_embed_scaling_factor
        self.pos_embed_offset = pos_embed_offset
        self.pos_embed_max_size = pos_embed_max_size
        self.x_block_self_attn_layers = x_block_self_attn_layers
        self.pos_emb_random_crop_rate = pos_emb_random_crop_rate
        self.gradient_checkpointing = use_checkpoint

        # hidden_size = default(hidden_size, 64 * depth)
        # num_heads = default(num_heads, hidden_size // 64)
        # apply magic --> this defines a head_size of 64
        self.hidden_size = 64 * depth
        num_heads = depth

        self.num_heads = num_heads

        self.enable_scaled_pos_embed(use_scaled_pos_embed, pos_embed_latent_sizes)

        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            self.hidden_size,
            bias=True,
            strict_img_size=self.pos_embed_max_size is None,
        )
        self.t_embedder = TimestepEmbedding(self.hidden_size)

        self.y_embedder = None
        if adm_in_channels is not None:
            assert isinstance(adm_in_channels, int)
            self.y_embedder = Embedder(adm_in_channels, self.hidden_size)

        if context_embedder_in_features is not None:
            self.context_embedder = nn.Linear(context_embedder_in_features, context_embedder_out_features)
        else:
            self.context_embedder = nn.Identity()

        self.register_length = register_length
        if self.register_length > 0:
            self.register = nn.Parameter(torch.randn(1, register_length, self.hidden_size))

        # num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        # just use a buffer already
        if num_patches is not None:
            self.register_buffer(
                "pos_embed",
                torch.empty(1, num_patches, self.hidden_size),
            )
        else:
            self.pos_embed = None

        self.use_checkpoint = use_checkpoint
        self.joint_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    self.hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_mode=attn_mode,
                    qkv_bias=qkv_bias,
                    pre_only=i == depth - 1,
                    rmsnorm=rmsnorm,
                    scale_mod_only=scale_mod_only,
                    swiglu=swiglu,
                    qk_norm=qk_norm,
                    x_block_self_attn=(i in self.x_block_self_attn_layers),
                )
                for i in range(depth)
            ]
        )
        for block in self.joint_blocks:
            block.gradient_checkpointing = use_checkpoint

        self.final_layer = UnPatch(self.hidden_size, patch_size, self.out_channels)
        # self.initialize_weights()

        self.blocks_to_swap = None
        self.offloader = None
        self.num_blocks = len(self.joint_blocks)

    def enable_scaled_pos_embed(self, use_scaled_pos_embed: bool, latent_sizes: Optional[list[int]]):
        self.use_scaled_pos_embed = use_scaled_pos_embed

        if self.use_scaled_pos_embed:
            # remove pos_embed to free up memory up to 0.4 GB
            self.pos_embed = None

            # remove duplicates and sort latent sizes in ascending order
            latent_sizes = list(set(latent_sizes))
            latent_sizes = sorted(latent_sizes)

            patched_sizes = [latent_size // self.patch_size for latent_size in latent_sizes]

            # calculate value range for each latent area: this is used to determine the pos_emb size from the latent shape
            max_areas = []
            for i in range(1, len(patched_sizes)):
                prev_area = patched_sizes[i - 1] ** 2
                area = patched_sizes[i] ** 2
                max_areas.append((prev_area + area) // 2)

            # area of the last latent size, if the latent size exceeds this, error will be raised
            max_areas.append(int((patched_sizes[-1] * MMDiT.POS_EMBED_MAX_RATIO) ** 2))
            # print("max_areas", max_areas)

            self.resolution_area_to_latent_size = [(area, latent_size) for area, latent_size in zip(max_areas, patched_sizes)]

            self.resolution_pos_embeds = {}
            for patched_size in patched_sizes:
                grid_size = int(patched_size * MMDiT.POS_EMBED_MAX_RATIO)
                pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, grid_size, sample_size=patched_size)
                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
                self.resolution_pos_embeds[patched_size] = pos_embed
                # print(f"pos_embed for {patched_size}x{patched_size} latent size: {pos_embed.shape}")
        else:
            self.resolution_area_to_latent_size = None
            self.resolution_pos_embeds = None

    @property
    def model_type(self):
        return self._model_type

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True
        for block in self.joint_blocks:
            block.enable_gradient_checkpointing()

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False
        for block in self.joint_blocks:
            block.disable_gradient_checkpointing()

    def initialize_weights(self):
        # TODO: Init context_embedder?

        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize (and freeze) pos_embed by sin-cos embedding
        if self.pos_embed is not None:
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1],
                int(self.pos_embed.shape[-2] ** 0.5),
                scaling_factor=self.pos_embed_scaling_factor,
            )
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        if getattr(self, "y_embedder", None) is not None:
            nn.init.normal_(self.y_embedder.mlp[0].weight, std=0.02)
            nn.init.normal_(self.y_embedder.mlp[2].weight, std=0.02)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.joint_blocks:
            nn.init.constant_(block.x_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.x_block.adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.context_block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.context_block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def set_pos_emb_random_crop_rate(self, rate: float):
        self.pos_emb_random_crop_rate = rate

    def cropped_pos_embed(self, h, w, device=None, random_crop: bool = False):
        p = self.x_embedder.patch_size
        # patched size
        h = (h + 1) // p
        w = (w + 1) // p
        if self.pos_embed is None:  # should not happen
            return get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=device)
        assert self.pos_embed_max_size is not None
        assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size)
        assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size)

        if not random_crop:
            top = (self.pos_embed_max_size - h) // 2
            left = (self.pos_embed_max_size - w) // 2
        else:
            top = torch.randint(0, self.pos_embed_max_size - h + 1, (1,)).item()
            left = torch.randint(0, self.pos_embed_max_size - w + 1, (1,)).item()

        spatial_pos_embed = self.pos_embed.reshape(
            1,
            self.pos_embed_max_size,
            self.pos_embed_max_size,
            self.pos_embed.shape[-1],
        )
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def cropped_scaled_pos_embed(self, h, w, device=None, dtype=None, random_crop: bool = False):
        p = self.x_embedder.patch_size
        # patched size
        h = (h + 1) // p
        w = (w + 1) // p

        # select pos_embed size based on area
        area = h * w
        patched_size = None
        for area_, patched_size_ in self.resolution_area_to_latent_size:
            if area <= area_:
                patched_size = patched_size_
                break
        if patched_size is None:
            raise ValueError(f"Area {area} is too large for the given latent sizes {self.resolution_area_to_latent_size}.")

        pos_embed = self.resolution_pos_embeds[patched_size]
        pos_embed_size = round(math.sqrt(pos_embed.shape[1]))
        if h > pos_embed_size or w > pos_embed_size:
            # # fallback to normal pos_embed
            # return self.cropped_pos_embed(h * p, w * p, device=device, random_crop=random_crop)
            # extend pos_embed size
            logger.warning(
                f"Extending pos_embed for size {h}x{w}: it exceeds the prepared scaled pos_embed size {pos_embed_size}. The image is unusually tall or wide."
            )
            pos_embed_size = max(h, w)
            pos_embed = get_scaled_2d_sincos_pos_embed(self.hidden_size, pos_embed_size, sample_size=patched_size)
            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)
            self.resolution_pos_embeds[patched_size] = pos_embed
            logger.info(f"Updated pos_embed for size {pos_embed_size}x{pos_embed_size}")

        if not random_crop:
            top = (pos_embed_size - h) // 2
            left = (pos_embed_size - w) // 2
        else:
            top = torch.randint(0, pos_embed_size - h + 1, (1,)).item()
            left = torch.randint(0, pos_embed_size - w + 1, (1,)).item()

        if pos_embed.device != device:
            pos_embed = pos_embed.to(device)
            # which is better to update device, or transfer every time to device? -> 64x64 emb is 96*96*1536*4=56MB. It's okay to update device.
            self.resolution_pos_embeds[patched_size] = pos_embed  # update device
        if pos_embed.dtype != dtype:
            pos_embed = pos_embed.to(dtype)
            self.resolution_pos_embeds[patched_size] = pos_embed  # update dtype

        spatial_pos_embed = pos_embed.reshape(1, pos_embed_size, pos_embed_size, pos_embed.shape[-1])
        spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        # print(
        #     f"patched size: {h}x{w}, pos_embed size: {pos_embed_size}, pos_embed shape: {pos_embed.shape}, top: {top}, left: {left}"
        # )
        return spatial_pos_embed

    def enable_block_swap(self, num_blocks: int, device: torch.device):
        self.blocks_to_swap = num_blocks

        assert (
            self.blocks_to_swap <= self.num_blocks - 2
        ), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."

        self.offloader = custom_offloading_utils.ModelOffloader(
            self.joint_blocks, self.num_blocks, self.blocks_to_swap, device  # , debug=True
        )
        print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")

    def move_to_device_except_swap_blocks(self, device: torch.device):
        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
        if self.blocks_to_swap:
            save_blocks = self.joint_blocks
            self.joint_blocks = None

        self.to(device)

        if self.blocks_to_swap:
            self.joint_blocks = save_blocks

    def prepare_block_swap_before_forward(self):
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self.offloader.prepare_block_devices_before_forward(self.joint_blocks)

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, D) tensor of class labels
        """
        pos_emb_random_crop = (
            False if self.pos_emb_random_crop_rate == 0.0 else torch.rand(1).item() < self.pos_emb_random_crop_rate
        )

        B, C, H, W = x.shape

        # x = self.x_embedder(x) + self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
        if not self.use_scaled_pos_embed:
            pos_embed = self.cropped_pos_embed(H, W, device=x.device, random_crop=pos_emb_random_crop).to(dtype=x.dtype)
        else:
            # print(f"Using scaled pos_embed for size {H}x{W}")
            pos_embed = self.cropped_scaled_pos_embed(H, W, device=x.device, dtype=x.dtype, random_crop=pos_emb_random_crop)
        x = self.x_embedder(x) + pos_embed
        del pos_embed

        c = self.t_embedder(t, dtype=x.dtype)  # (N, D)
        if y is not None and self.y_embedder is not None:
            y = self.y_embedder(y)  # (N, D)
            c = c + y  # (N, D)

        if context is not None:
            context = self.context_embedder(context)

        if self.register_length > 0:
            context = torch.cat(
                (einops.repeat(self.register, "1 ... -> b ...", b=x.shape[0]), default(context, torch.Tensor([]).type_as(x))), 1
            )

        if not self.blocks_to_swap:
            for block in self.joint_blocks:
                context, x = block(context, x, c)
        else:
            for block_idx, block in enumerate(self.joint_blocks):
                self.offloader.wait_for_block(block_idx)

                context, x = block(context, x, c)

                self.offloader.submit_move_blocks(self.joint_blocks, block_idx)

        x = self.final_layer(x, c, H, W)  # Our final layer combined UnPatchify
        return x[:, :, :H, :W]

def create_sd3_mmdit(params: SD3Params, attn_mode: str = "torch") -> MMDiT:
    mmdit = MMDiT(
        input_size=None,
        pos_embed_max_size=params.pos_embed_max_size,
        patch_size=params.patch_size,
        in_channels=16,
        adm_in_channels=params.adm_in_channels,
        context_embedder_in_features=params.context_embedder_in_features,
        context_embedder_out_features=params.context_embedder_out_features,
        depth=params.depth,
        mlp_ratio=4,
        qk_norm=params.qk_norm,
        x_block_self_attn_layers=params.x_block_self_attn_layers,
        num_patches=params.num_patches,
        attn_mode=attn_mode,
        model_type=params.model_type,
    )
    return mmdit
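
# Example (sketch): instantiating an SD3-Medium-sized MMDiT and running a dummy step.
# The parameter values below are assumptions for illustration, not loaded from a checkpoint config.
# if __name__ == "__main__":
#     params = SD3Params(
#         patch_size=2, depth=24, num_patches=36864, pos_embed_max_size=192,
#         adm_in_channels=2048, qk_norm=None, x_block_self_attn_layers=[],
#         context_embedder_in_features=4096, context_embedder_out_features=1536,
#         model_type="sd3m",
#     )
#     mmdit = create_sd3_mmdit(params, attn_mode="torch")
#     x = torch.randn(1, 16, 64, 64)  # latent
#     t = torch.randint(0, 1000, (1,))  # timestep
#     y = torch.randn(1, 2048)  # pooled text embedding
#     context = torch.randn(1, 154, 4096)  # per-token text embedding
#     print(mmdit(x, t, y=y, context=context).shape)  # torch.Size([1, 16, 64, 64])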

# endregion

# region VAE

VAE_SCALE_FACTOR = 1.5305
VAE_SHIFT_FACTOR = 0.0609


def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)

class ResnetBlock(torch.nn.Module):
    def __init__(self, *, in_channels, out_channels=None, dtype=torch.float32, device=None):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels, dtype=dtype, device=device)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.norm2 = Normalize(out_channels, dtype=dtype, device=device)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device
            )
        else:
            self.nin_shortcut = None
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        hidden = x
        hidden = self.norm1(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv1(hidden)
        hidden = self.norm2(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv2(hidden)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + hidden

class AttnBlock(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.norm = Normalize(in_channels, dtype=dtype, device=device)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, dtype=dtype, device=device)

    def forward(self, x):
        hidden = self.norm(x)
        q = self.q(hidden)
        k = self.k(hidden)
        v = self.v(hidden)
        b, c, h, w = q.shape
        q, k, v = map(lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v))
        hidden = torch.nn.functional.scaled_dot_product_attention(q, k, v)  # scale is dim ** -0.5 per default
        hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
        hidden = self.proj_out(hidden)
        return x + hidden

class Downsample(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0, dtype=dtype, device=device)

    def forward(self, x):
        pad = (0, 1, 0, 1)
        x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x


class Upsample(torch.nn.Module):
    def __init__(self, in_channels, dtype=torch.float32, device=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)

    def forward(self, x):
        org_dtype = x.dtype
        if x.dtype == torch.bfloat16:
            x = x.to(torch.float32)
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if x.dtype != org_dtype:
            x = x.to(org_dtype)
        x = self.conv(x)
        return x

class VAEEncoder(torch.nn.Module):
    def __init__(
        self, ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2, in_channels=3, z_channels=16, dtype=torch.float32, device=None
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = torch.nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = torch.nn.ModuleList()
            attn = torch.nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            down = torch.nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, dtype=dtype, device=device)
            self.down.append(down)

        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)

        # end
        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, x):
        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1])
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = self.swish(h)
        h = self.conv_out(h)
        return h

class VAEDecoder(torch.nn.Module):
    def __init__(
        self,
        ch=128,
        out_ch=3,
        ch_mult=(1, 2, 4, 4),
        num_res_blocks=2,
        resolution=256,
        z_channels=16,
        dtype=torch.float32,
        device=None,
    ):
        super().__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)

        # middle
        self.mid = torch.nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)
        self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dtype=dtype, device=device)

        # upsampling
        self.up = torch.nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = torch.nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dtype=dtype, device=device))
                block_in = block_out
            up = torch.nn.Module()
            up.block = block
            if i_level != 0:
                up.upsample = Upsample(block_in, dtype=dtype, device=device)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in, dtype=dtype, device=device)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1, dtype=dtype, device=device)
        self.swish = torch.nn.SiLU(inplace=True)

    def forward(self, z):
        # z to block_in
        hidden = self.conv_in(z)

        # middle
        hidden = self.mid.block_1(hidden)
        hidden = self.mid.attn_1(hidden)
        hidden = self.mid.block_2(hidden)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                hidden = self.up[i_level].block[i_block](hidden)
            if i_level != 0:
                hidden = self.up[i_level].upsample(hidden)

        # end
        hidden = self.norm_out(hidden)
        hidden = self.swish(hidden)
        hidden = self.conv_out(hidden)
        return hidden

class SDVAE(torch.nn.Module):
    def __init__(self, dtype=torch.float32, device=None):
        super().__init__()
        self.encoder = VAEEncoder(dtype=dtype, device=device)
        self.decoder = VAEDecoder(dtype=dtype, device=device)

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    # @torch.autocast("cuda", dtype=torch.float16)
    def decode(self, latent):
        return self.decoder(latent)

    # @torch.autocast("cuda", dtype=torch.float16)
    def encode(self, image):
        hidden = self.encoder(image)
        mean, logvar = torch.chunk(hidden, 2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)

def process_in(latent):
    return (latent - VAE_SHIFT_FACTOR) * VAE_SCALE_FACTOR


def process_out(latent):
    return (latent / VAE_SCALE_FACTOR) + VAE_SHIFT_FACTOR
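
# Example (sketch): round-tripping an image through the VAE using the process_in/process_out
# helpers above to apply the latent shift/scale the MMDiT expects. Sizes are illustrative.
# if __name__ == "__main__":
#     vae = SDVAE()
#     image = torch.randn(1, 3, 256, 256)  # RGB input, roughly in [-1, 1] in practice
#     latent = process_in(vae.encode(image))  # (1, 16, 32, 32), scaled and shifted
#     recon = vae.decode(process_out(latent))  # (1, 3, 256, 256)
#     print(latent.shape, recon.shape)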

# endregion