# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from typing import Callable, List, Optional

import torch
from einops import rearrange
from torch import nn

from common.cache import Cache
from common.distributed.ops import slice_inputs


# Factory signature for adaptive-modulation layers:
# (dim: int, emb_dim: int, layers: List[str]) -> nn.Module
ada_layer_type = Callable[[int, int, List[str]], nn.Module]


def get_ada_layer(ada_layer: str) -> ada_layer_type:
    """Look up an adaptive-modulation layer class by name."""
    if ada_layer == "single":
        return AdaSingle
    raise NotImplementedError(f"{ada_layer} is not supported")
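
# Usage sketch (illustrative only; "single" is the sole variant registered above):
#   get_ada_layer("single") is AdaSingle  # -> True
#   get_ada_layer("double")               # -> raises NotImplementedError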


def expand_dims(x: torch.Tensor, dim: int, ndim: int) -> torch.Tensor:
    """
    Expand tensor "x" to "ndim" dimensions by inserting singleton dims at "dim".
    Example: x is (b d), target ndim is 5, insert at dim 1, return (b 1 1 1 d).
    """
    shape = x.shape
    shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:]
    return x.reshape(shape)
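
# Illustrative check of the docstring example (hypothetical shapes):
#   expand_dims(torch.randn(2, 8), dim=1, ndim=5).shape
#   # -> torch.Size([2, 1, 1, 1, 8]), broadcastable against a 5-D hidden state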


class AdaSingle(nn.Module):
    """
    Adaptive modulation with a single shared embedding: each modulated
    sub-layer combines a learnable (shift, scale, gate) parameter triple
    with a per-sample triple sliced from the conditioning embedding.
    """

    def __init__(
        self,
        dim: int,
        emb_dim: int,
        layers: List[str],
    ):
        assert emb_dim == 3 * len(layers) * dim, (
            "AdaSingle requires emb_dim == 3 * len(layers) * dim "
            "(one shift, scale and gate vector per modulated layer)"
        )
        super().__init__()
        self.dim = dim
        self.emb_dim = emb_dim
        self.layers = layers
        # One learnable (shift, scale, gate) triple per modulated layer.
        # Scale is initialized around 1 so modulation starts near identity.
        for l in layers:
            self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
            self.register_parameter(f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1))
            self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))

    def forward(
        self,
        hid: torch.FloatTensor,  # b ... c
        emb: torch.FloatTensor,  # b d
        layer: str,
        mode: str,
        cache: Cache = Cache(disable=True),
        branch_tag: str = "",
        hid_len: Optional[torch.LongTensor] = None,  # b
    ) -> torch.FloatTensor:
        # Slice this layer's (shift, scale, gate) triple out of the embedding:
        # (b, emb_dim) -> (b, dim, n_layers, 3) -> (b, dim, 3).
        idx = self.layers.index(layer)
        emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
        # Insert singleton dims so the triple broadcasts against hid: (b, 1, ..., dim, 3).
        emb = expand_dims(emb, 1, hid.ndim + 1)

        if hid_len is not None:
            # Packed-sequence path: repeat each sample's triple to its sequence
            # length, concatenate along the leading token dim, and slice out the
            # local shard for sequence parallelism. Cached per layer and branch
            # so the repeat runs only once.
            emb = cache(
                f"emb_repeat_{idx}_{branch_tag}",
                lambda: slice_inputs(
                    torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
                    dim=0,
                ),
            )

        shift_emb, scale_emb, gate_emb = emb.unbind(-1)
        shift_param, scale_param, gate_param = (
            getattr(self, f"{layer}_shift"),
            getattr(self, f"{layer}_scale"),
            getattr(self, f"{layer}_gate"),
        )

        # Note: both branches modify hid in place.
        if mode == "in":
            return hid.mul_(scale_emb + scale_param).add_(shift_emb + shift_param)
        if mode == "out":
            return hid.mul_(gate_emb + gate_param)
        raise NotImplementedError(f"mode {mode!r} is not supported")

    def extra_repr(self) -> str:
        return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}"