# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# //     http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from typing import Optional, Union

import torch
from diffusers.models.embeddings import get_timestep_embedding
from torch import nn


def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]):
    """Add two embeddings, treating a missing second embedding as zero."""
    return emb1 if emb2 is None else emb1 + emb2


class TimeEmbedding(nn.Module):
    """Maps a diffusion timestep to a learned embedding: fixed sinusoidal
    features followed by a three-layer MLP (Linear-SiLU-Linear-SiLU-Linear)."""

    def __init__(
        self,
        sinusoidal_dim: int,
        hidden_dim: int,
        output_dim: int,
    ):
        super().__init__()
        self.sinusoidal_dim = sinusoidal_dim
        self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
        self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, output_dim)
        self.act = nn.SiLU()

    def forward(
        self,
        timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.FloatTensor:
        # Promote scalar inputs to a 1-D tensor so get_timestep_embedding
        # always receives a batch of timesteps.
        if not torch.is_tensor(timestep):
            timestep = torch.tensor([timestep], device=device, dtype=dtype)
        if timestep.ndim == 0:
            timestep = timestep[None]
        # Fixed sinusoidal features of width `sinusoidal_dim`.
        emb = get_timestep_embedding(
            timesteps=timestep,
            embedding_dim=self.sinusoidal_dim,
            flip_sin_to_cos=False,
            downscale_freq_shift=0,
        )
        emb = emb.to(dtype)
        # Project through the MLP: in -> hidden -> out, with SiLU between layers.
        emb = self.proj_in(emb)
        emb = self.act(emb)
        emb = self.proj_hid(emb)
        emb = self.act(emb)
        emb = self.proj_out(emb)
        return emb
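

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original repository: the dimensions
    # and timestep value below are illustrative assumptions only. It builds the
    # module and embeds a single scalar timestep.
    time_emb = TimeEmbedding(sinusoidal_dim=256, hidden_dim=1024, output_dim=1024)
    emb = time_emb(timestep=500, device=torch.device("cpu"), dtype=torch.float32)
    print(emb.shape)  # expected: torch.Size([1, 1024])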