# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from typing import Tuple, Union

import torch
from einops import rearrange
from torch import nn
from torch.nn.modules.utils import _triple

from common.cache import Cache
from common.distributed.ops import gather_outputs, slice_inputs

from . import na

class PatchIn(nn.Module):
    """Embeds non-overlapping spatiotemporal patches of a dense video into tokens."""

    def __init__(
        self,
        in_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
    ):
        super().__init__()
        t, h, w = _triple(patch_size)
        self.patch_size = t, h, w
        self.proj = nn.Linear(in_channels * t * h * w, dim)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        t, h, w = self.patch_size
        # Fold each (t, h, w) patch and its channels into the last dim, then project to `dim`.
        vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
        vid = self.proj(vid)
        return vid

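# Illustrative usage sketch (not part of the original module). The channel count,
# patch size, and video extents below are assumptions chosen only to show the
# shape transformation:
#
#   patch_in = PatchIn(in_channels=16, patch_size=(1, 2, 2), dim=1024)
#   vid = torch.randn(2, 16, 8, 32, 32)   # (b, c, t, h, w)
#   tokens = patch_in(vid)                # -> (2, 8, 16, 16, 1024), i.e. (b, T, H, W, dim)
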
class PatchOut(nn.Module):
    """Projects patch tokens back to pixel values and unfolds them into a dense video."""

    def __init__(
        self,
        out_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
    ):
        super().__init__()
        t, h, w = _triple(patch_size)
        self.patch_size = t, h, w
        self.proj = nn.Linear(dim, out_channels * t * h * w)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        t, h, w = self.patch_size
        # Project each token to (t * h * w * c) values, then unfold them back into a video.
        vid = self.proj(vid)
        vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w)
        return vid

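# Illustrative usage sketch (not part of the original module). PatchOut inverts
# the PatchIn mapping above; the shapes are the same assumed values, run in reverse:
#
#   patch_out = PatchOut(out_channels=16, patch_size=(1, 2, 2), dim=1024)
#   tokens = torch.randn(2, 8, 16, 16, 1024)   # (b, T, H, W, dim)
#   vid = patch_out(tokens)                    # -> (2, 16, 8, 32, 32), i.e. (b, c, t, h, w)
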
class NaPatchIn(PatchIn):
    """PatchIn for flattened token sequences with per-sample shapes (variable-size batches)."""

    def forward(
        self,
        vid: torch.Tensor,  # l c
        vid_shape: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        t, h, w = self.patch_size
        if not (t == h == w == 1):
            vid, vid_shape = na.rearrange(
                vid, vid_shape, "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w
            )
        # Slice vid after patching in when using sequence parallelism.
        vid = slice_inputs(vid, dim=0)
        vid = self.proj(vid)
        return vid, vid_shape

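# Illustrative usage sketch (not part of the original module). NaPatchIn consumes
# a flattened (L, c) token sequence plus a LongTensor of per-sample shapes, so a
# batch can mix resolutions and durations; the exact layout of vid_shape is
# defined by the project's `na` helpers and is assumed, not specified, here:
#
#   na_patch_in = NaPatchIn(in_channels=16, patch_size=(1, 2, 2), dim=1024)
#   tokens, token_shape = na_patch_in(vid_flat, vid_shape)   # vid_flat: (L, 16)
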
class NaPatchOut(PatchOut):
    """PatchOut for flattened token sequences with per-sample shapes (variable-size batches)."""

    def forward(
        self,
        vid: torch.FloatTensor,  # l c
        vid_shape: torch.LongTensor,
        cache: Cache = Cache(disable=True),
    ) -> Tuple[
        torch.FloatTensor,
        torch.LongTensor,
    ]:
        t, h, w = self.patch_size
        vid = self.proj(vid)
        # Gather vid before patching out when sequence parallelism is enabled.
        vid = gather_outputs(
            vid,
            gather_dim=0,
            padding_dim=0,
            unpad_shape=vid_shape,
            cache=cache.namespace("vid"),
        )
        if not (t == h == w == 1):
            vid, vid_shape = na.rearrange(
                vid, vid_shape, "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w
            )
        return vid, vid_shape

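# Illustrative roundtrip sketch (not part of the original module). NaPatchOut
# undoes NaPatchIn on flattened sequences, first gathering sequence-parallel
# shards; passing an explicit Cache is optional here since the default is a
# disabled Cache, which is assumed adequate outside the full pipeline:
#
#   na_patch_out = NaPatchOut(out_channels=16, patch_size=(1, 2, 2), dim=1024)
#   vid_flat_rec, vid_shape_rec = na_patch_out(tokens, token_shape)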