Spaces:
Build error
Build error
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from typing import List, Optional, Tuple, Union | |
import torch | |
from torch import nn | |
from diffusers.models.activations import get_activation | |
class TimestepEmbedding(nn.Module):
    """Two-layer MLP that projects an embedding vector (e.g. a sinusoidal timestep embedding).

    Computes ``post_act(linear_2(act(linear_1(sample [+ cond_proj(condition)]))))``, where the
    conditioning projection and the post-activation are optional.

    Args:
        in_channels: Size of the last dimension of ``sample``.
        time_embed_dim: Output size of ``linear_1``; also the output size of the module
            unless ``out_dim`` is given.
        act_fn: Name of the activation applied between the two linear layers, resolved
            through ``get_activation``.
        out_dim: Optional output size of ``linear_2``; defaults to ``time_embed_dim``.
        post_act_fn: Optional name of an activation applied after ``linear_2``.
        cond_proj_dim: If given, a bias-free ``nn.Linear(cond_proj_dim, in_channels)`` is
            created so that a ``condition`` tensor can be added to ``sample`` in ``forward``.
        sample_proj_bias: Whether ``linear_1`` and ``linear_2`` have bias terms.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
        sample_proj_bias: bool = True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=sample_proj_bias)

        # Optional bias-free projection of an extra conditioning vector into the input space.
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=sample_proj_bias)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialise ``linear_1`` and ``linear_2``: weights ~ N(0, 0.02^2), biases to zero.

        ``cond_proj`` (if any) keeps PyTorch's default ``nn.Linear`` initialisation.
        Biases are only zeroed when present: with ``sample_proj_bias=False`` the layers
        have ``bias=None`` and ``nn.init.zeros_`` cannot be applied to it.
        """
        for layer in (self.linear_1, self.linear_2):
            nn.init.normal_(layer.weight, std=0.02)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, sample, condition=None):
        """Embed ``sample``, optionally adding the projected ``condition`` first.

        Args:
            sample: Tensor whose last dimension is ``in_channels``.
            condition: Optional tensor whose last dimension is ``cond_proj_dim``. Only
                accepted when the module was built with ``cond_proj_dim``.

        Returns:
            Tensor whose last dimension is ``out_dim`` (or ``time_embed_dim``).

        Raises:
            ValueError: If ``condition`` is given but no ``cond_proj`` layer exists.
        """
        if condition is not None:
            if self.cond_proj is None:
                raise ValueError(
                    "`condition` was passed but this TimestepEmbedding was built without `cond_proj_dim`."
                )
            sample = sample + self.cond_proj(condition)

        sample = self.linear_1(sample)
        # Kept defensive: the activation looked up by name may be absent.
        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a query or key tensor.

    Each pair of channels in the last dimension of ``x`` is treated as a complex number and
    rotated by the corresponding entry of ``freqs_cis``.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to rotate, 4-D with the channel dimension ``D`` last.
            For ``use_real=True`` the layout is ``[B, H, S, D]`` (``cos``/``sin`` are
            broadcast as ``[1, 1, S, D]``).
        freqs_cis (`Tuple[torch.Tensor, torch.Tensor]` or `torch.Tensor`):
            With ``use_real=True``: a ``(cos, sin)`` pair, each of shape ``[S, D]``.
            With ``use_real=False``: a complex tensor of per-pair rotations that is
            unsqueezed at dim 2 before multiplication (presumably ``[B, S, D // 2]`` with
            ``x`` laid out as ``[B, S, H, D]`` -- TODO confirm against callers).
        use_real (`bool`, defaults to `True`):
            Use the real-valued ``cos``/``sin`` formulation (``True``) or complex
            multiplication (``False``).
        use_real_unbind_dim (`int`, defaults to `-1`):
            Channel pairing for the real-valued path: ``-1`` pairs adjacent channels
            ``(x[2i], x[2i + 1])``; ``-2`` pairs the two halves ``(x[i], x[i + D // 2])``.

    Returns:
        `torch.Tensor`: The rotated tensor, cast back to ``x.dtype``.

    Raises:
        ValueError: If ``use_real`` is ``True`` and ``use_real_unbind_dim`` is not -1 or -2.
    """
    # Explicit half size instead of -1: reshape(..., -1, 2) is ambiguous (and raises) when
    # x has zero elements, e.g. an empty batch.
    half = x.shape[-1] // 2

    if use_real:
        cos, sin = freqs_cis  # each [S, D]
        # Two leading broadcast dims so [S, D] lines up with x: [B, H, S, D].
        cos = cos[None, None].to(x.device)
        sin = sin[None, None].to(x.device)
        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit: channels interleaved as (re, im) pairs.
            x_real, x_imag = x.reshape(*x.shape[:-1], half, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen and CogView4: first half re, second half im.
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, half).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
        # Upcast x to float32 for the rotation, then cast back to the input dtype.
        return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

    # Complex formulation (used for lumina).
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], half, 2))
    x_out = torch.view_as_real(x_complex * freqs_cis.unsqueeze(2)).flatten(3)
    return x_out.type_as(x)