# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Callable

import torch
from torch import nn

from common.cache import Cache
from common.distributed.ops import slice_inputs

from . import na
from .embedding import TimeEmbedding
from .modulation import get_ada_layer
from .nablocks import get_nablock
from .normalization import get_norm_layer
from .patch import NaPatchIn, NaPatchOut


# Fake function; no gradient checkpointing is required for inference.
def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs):
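    """
    Inference-only stand-in for a gradient-checkpointing wrapper: it calls
    ``module`` directly; the ``enabled`` flag is accepted for signature
    compatibility but otherwise ignored.
    """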
    return module(*args, **kwargs)


@dataclass
class NaDiTOutput:
    vid_sample: torch.Tensor


class NaDiT(nn.Module):
    """
    Native Resolution Diffusion Transformer (NaDiT)
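
    Data flow (as implemented below): video latents arrive as a packed token
    sequence and are patchified by ``NaPatchIn``; text embeddings are optionally
    projected to ``txt_dim``; both streams are then processed, together with a
    sinusoidal timestep embedding, by ``num_layers`` blocks obtained from
    ``get_nablock``, and the video tokens are finally un-patchified by
    ``NaPatchOut``.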
| """ | |
| gradient_checkpointing = False | |

    def __init__(
        self,
        vid_in_channels: int,
        vid_out_channels: int,
        vid_dim: int,
        txt_in_dim: Optional[int],
        txt_dim: Optional[int],
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: Optional[str],
        norm_eps: float,
        ada: str,
        qk_bias: bool,
        qk_rope: bool,
        qk_norm: Optional[str],
        patch_size: Union[int, Tuple[int, int, int]],
        num_layers: int,
        block_type: Union[str, Tuple[str]],
        shared_qkv: bool = False,
        shared_mlp: bool = False,
        mlp_type: str = "normal",
        window: Optional[Tuple] = None,
        window_method: Optional[Tuple[str]] = None,
        temporal_window_size: Optional[int] = None,
        temporal_shifted: bool = False,
        **kwargs,
    ):
        ada = get_ada_layer(ada)
        norm = get_norm_layer(norm)
        qk_norm = get_norm_layer(qk_norm)
        if isinstance(block_type, str):
            block_type = [block_type] * num_layers
        elif len(block_type) != num_layers:
            raise ValueError("The length of ``block_type`` must equal ``num_layers``.")
        super().__init__()

        self.vid_in = NaPatchIn(
            in_channels=vid_in_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )
        self.txt_in = (
            nn.Linear(txt_in_dim, txt_dim)
            if txt_in_dim and txt_in_dim != txt_dim
            else nn.Identity()
        )
        self.emb_in = TimeEmbedding(
            sinusoidal_dim=256,
            hidden_dim=max(vid_dim, txt_dim),
            output_dim=emb_dim,
        )

        # Broadcast per-layer settings when a single value is given.
        if window is None or isinstance(window[0], int):
            window = [window] * num_layers
        if window_method is None or isinstance(window_method, str):
            window_method = [window_method] * num_layers
        if temporal_window_size is None or isinstance(temporal_window_size, int):
            temporal_window_size = [temporal_window_size] * num_layers
        if temporal_shifted is None or isinstance(temporal_shifted, bool):
            temporal_shifted = [temporal_shifted] * num_layers
        self.blocks = nn.ModuleList(
            [
                get_nablock(block_type[i])(
                    vid_dim=vid_dim,
                    txt_dim=txt_dim,
                    emb_dim=emb_dim,
                    heads=heads,
                    head_dim=head_dim,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    norm_eps=norm_eps,
                    ada=ada,
                    qk_bias=qk_bias,
                    qk_rope=qk_rope,
                    qk_norm=qk_norm,
                    shared_qkv=shared_qkv,
                    shared_mlp=shared_mlp,
                    mlp_type=mlp_type,
                    window=window[i],
                    window_method=window_method[i],
                    temporal_window_size=temporal_window_size[i],
                    temporal_shifted=temporal_shifted[i],
                    **kwargs,
                )
                for i in range(num_layers)
            ]
        )
        self.vid_out = NaPatchOut(
            out_channels=vid_out_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )
        self.need_txt_repeat = block_type[0] in [
            "mmdit_stwin",
            "mmdit_stwin_spatial",
            "mmdit_stwin_3d_spatial",
        ]

    def set_gradient_checkpointing(self, enable: bool):
        self.gradient_checkpointing = enable

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],  # b
        disable_cache: bool = True,  # for test
    ):
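        """
        ``vid`` and ``txt`` are packed token sequences of shape (length, channels);
        ``vid_shape`` and ``txt_shape`` carry the corresponding per-sample sequence
        layouts (``vid_shape[:, 0]`` is used as the per-sample frame count when the
        text tokens are repeated per frame below).
        """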
        # Text input.
        if txt_shape.size(-1) == 1 and self.need_txt_repeat:
            txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
        # Slice text here; vid is sliced after patching-in when using sequence parallelism.
        txt = slice_inputs(txt, dim=0)
        txt = self.txt_in(txt)

        # Video input.
        # Sequence parallel slicing is done inside the patching class.
        vid, vid_shape = self.vid_in(vid, vid_shape)

        # Embedding input.
        emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)

        # Body.
        cache = Cache(disable=disable_cache)
        for block in self.blocks:
            vid, txt, vid_shape, txt_shape = gradient_checkpointing(
                enabled=(self.gradient_checkpointing and self.training),
                module=block,
                vid=vid,
                txt=txt,
                vid_shape=vid_shape,
                txt_shape=txt_shape,
                emb=emb,
                cache=cache,
            )

        vid, vid_shape = self.vid_out(vid, vid_shape, cache)
        return NaDiTOutput(vid_sample=vid)


class NaDiTUpscaler(nn.Module):
    """
    Native Resolution Diffusion Transformer (NaDiT) upscaler variant.
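
    Structurally identical to ``NaDiT``, but additionally conditioned on a
    ``downscale`` factor: a second ``TimeEmbedding`` (``self.emb_scale``) embeds
    that factor and its output is added to the timestep embedding in ``forward``.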
| """ | |
| gradient_checkpointing = False | |

    def __init__(
        self,
        vid_in_channels: int,
        vid_out_channels: int,
        vid_dim: int,
        txt_in_dim: Optional[int],
        txt_dim: Optional[int],
        emb_dim: int,
        heads: int,
        head_dim: int,
        expand_ratio: int,
        norm: Optional[str],
        norm_eps: float,
        ada: str,
        qk_bias: bool,
        qk_rope: bool,
        qk_norm: Optional[str],
        patch_size: Union[int, Tuple[int, int, int]],
        num_layers: int,
        block_type: Union[str, Tuple[str]],
        shared_qkv: bool = False,
        shared_mlp: bool = False,
        mlp_type: str = "normal",
        window: Optional[Tuple] = None,
        window_method: Optional[Tuple[str]] = None,
        temporal_window_size: Optional[int] = None,
        temporal_shifted: bool = False,
        **kwargs,
    ):
        ada = get_ada_layer(ada)
        norm = get_norm_layer(norm)
        qk_norm = get_norm_layer(qk_norm)
        if isinstance(block_type, str):
            block_type = [block_type] * num_layers
        elif len(block_type) != num_layers:
            raise ValueError("The length of ``block_type`` must equal ``num_layers``.")
        super().__init__()

        self.vid_in = NaPatchIn(
            in_channels=vid_in_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )
        self.txt_in = (
            nn.Linear(txt_in_dim, txt_dim)
            if txt_in_dim and txt_in_dim != txt_dim
            else nn.Identity()
        )
        self.emb_in = TimeEmbedding(
            sinusoidal_dim=256,
            hidden_dim=max(vid_dim, txt_dim),
            output_dim=emb_dim,
        )
        # Embedding of the downscale factor, added to the timestep embedding in forward().
        self.emb_scale = TimeEmbedding(
            sinusoidal_dim=256,
            hidden_dim=max(vid_dim, txt_dim),
            output_dim=emb_dim,
        )

        # Broadcast per-layer settings when a single value is given.
        if window is None or isinstance(window[0], int):
            window = [window] * num_layers
        if window_method is None or isinstance(window_method, str):
            window_method = [window_method] * num_layers
        if temporal_window_size is None or isinstance(temporal_window_size, int):
            temporal_window_size = [temporal_window_size] * num_layers
        if temporal_shifted is None or isinstance(temporal_shifted, bool):
            temporal_shifted = [temporal_shifted] * num_layers
        self.blocks = nn.ModuleList(
            [
                get_nablock(block_type[i])(
                    vid_dim=vid_dim,
                    txt_dim=txt_dim,
                    emb_dim=emb_dim,
                    heads=heads,
                    head_dim=head_dim,
                    expand_ratio=expand_ratio,
                    norm=norm,
                    norm_eps=norm_eps,
                    ada=ada,
                    qk_bias=qk_bias,
                    qk_rope=qk_rope,
                    qk_norm=qk_norm,
                    shared_qkv=shared_qkv,
                    shared_mlp=shared_mlp,
                    mlp_type=mlp_type,
                    window=window[i],
                    window_method=window_method[i],
                    temporal_window_size=temporal_window_size[i],
                    temporal_shifted=temporal_shifted[i],
                    **kwargs,
                )
                for i in range(num_layers)
            ]
        )
        self.vid_out = NaPatchOut(
            out_channels=vid_out_channels,
            patch_size=patch_size,
            dim=vid_dim,
        )
        self.need_txt_repeat = block_type[0] in [
            "mmdit_stwin",
            "mmdit_stwin_spatial",
            "mmdit_stwin_3d_spatial",
        ]

    def set_gradient_checkpointing(self, enable: bool):
        self.gradient_checkpointing = enable

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        txt: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,  # b 3
        txt_shape: torch.LongTensor,  # b 1
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],  # b
        downscale: Union[int, float, torch.IntTensor, torch.FloatTensor],  # b
        disable_cache: bool = False,  # for test
    ):
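        """
        Same packed-sequence inputs as ``NaDiT.forward``, plus ``downscale``: a
        per-sample conditioning value that is embedded by ``self.emb_scale`` and
        added to the timestep embedding.
        """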
        # Text input.
        if txt_shape.size(-1) == 1 and self.need_txt_repeat:
            txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
        # Slice text here; vid is sliced after patching-in when using sequence parallelism.
        txt = slice_inputs(txt, dim=0)
        txt = self.txt_in(txt)

        # Video input.
        # Sequence parallel slicing is done inside the patching class.
        vid, vid_shape = self.vid_in(vid, vid_shape)

        # Embedding input: timestep embedding plus downscale-factor embedding.
        emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
        emb_scale = self.emb_scale(downscale, device=vid.device, dtype=vid.dtype)
        emb = emb + emb_scale

        # Body.
        cache = Cache(disable=disable_cache)
        for block in self.blocks:
            vid, txt, vid_shape, txt_shape = gradient_checkpointing(
                enabled=(self.gradient_checkpointing and self.training),
                module=block,
                vid=vid,
                txt=txt,
                vid_shape=vid_shape,
                txt_shape=txt_shape,
                emb=emb,
                cache=cache,
            )

        vid, vid_shape = self.vid_out(vid, vid_shape, cache)
        return NaDiTOutput(vid_sample=vid)