from typing import Optional, Tuple, Union

import torch
from torch import nn

from diffusers.models.activations import get_activation
|
|
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
        sample_proj_bias: bool = True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        # Optional projection for an extra conditioning signal that is added to
        # the sample before the first linear layer.
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

        self.initialize_weights()

    def initialize_weights(self):
        nn.init.normal_(self.linear_1.weight, std=0.02)
        # Guard the bias initializations: when `sample_proj_bias=False` the
        # linear layers have no bias and `zeros_` would fail on `None`.
        if self.linear_1.bias is not None:
            nn.init.zeros_(self.linear_1.bias)
        nn.init.normal_(self.linear_2.weight, std=0.02)
        if self.linear_2.bias is not None:
            nn.init.zeros_(self.linear_2.bias)

    def forward(self, sample: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor:
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
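
# Usage sketch (illustrative, not part of the module): the sizes below are
# assumptions, e.g. a UNet-style setup where 320-dim sinusoidal timestep
# features are projected to a 1280-dim embedding.
#
#     temb = TimestepEmbedding(in_channels=320, time_embed_dim=1280)
#     t_feat = torch.randn(4, 320)  # batch of sinusoidal timestep features
#     emb = temb(t_feat)            # -> shape [4, 1280]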
|
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary position embeddings to the given query or key tensor `x` using the provided frequency tensor
    `freqs_cis`. Pairs of elements along the last dimension of `x` are rotated; when `use_real` is `False`, the
    input is reshaped as complex numbers and the frequency tensor is broadcast for compatibility. The result is
    returned as a real tensor with the same shape and dtype as `x`.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, of shape `[B, H, S, D]`.
        freqs_cis (`Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]`):
            Precomputed frequency tensor: a `(cos, sin)` tuple of shapes `([S, D], [S, D])` when `use_real` is
            `True`, otherwise a single complex tensor.
        use_real (`bool`, defaults to `True`):
            Whether `freqs_cis` is given as real `(cos, sin)` tensors rather than a complex tensor.
        use_real_unbind_dim (`int`, defaults to `-1`):
            `-1` treats the last dimension of `x` as interleaved `(real, imag)` pairs; `-2` treats it as the real
            half followed by the imaginary half.

    Returns:
        `torch.Tensor`: The input tensor with rotary embeddings applied.
    """
    if use_real:
        cos, sin = freqs_cis  # each of shape [S, D]
        cos = cos[None, None]  # broadcast over batch and heads -> [1, 1, S, D]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Last dimension holds interleaved (real, imag) pairs.
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # each [B, H, S, D // 2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Last dimension holds the real half followed by the imaginary half.
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # each [B, H, S, D // 2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # Interpret consecutive pairs in the last dimension of `x` as complex numbers
        # and rotate them by complex multiplication with `freqs_cis`, which gains a
        # new axis at dim 2 so it broadcasts against the input.
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], x.shape[-1] // 2, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
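

if __name__ == "__main__":
    # Minimal smoke test for `apply_rotary_emb` (illustrative; the shapes and the
    # 10000 frequency base are assumptions, not mandated by this module).
    B, H, S, D = 2, 3, 4, 8
    pos = torch.arange(S, dtype=torch.float32)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2, dtype=torch.float32) / D))
    angles = torch.outer(pos, inv_freq)  # [S, D // 2]
    # Repeat each angle twice so (cos, sin) match the interleaved pair layout
    # expected by `use_real_unbind_dim=-1`.
    cos = angles.cos().repeat_interleave(2, dim=-1)  # [S, D]
    sin = angles.sin().repeat_interleave(2, dim=-1)  # [S, D]

    q = torch.randn(B, H, S, D)
    q_rot = apply_rotary_emb(q, (cos, sin))
    assert q_rot.shape == q.shape
    # Position 0 has zero rotation angle, so it must pass through unchanged.
    torch.testing.assert_close(q_rot[:, :, 0], q[:, :, 0])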