|
from typing import List, Tuple |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from einops import repeat |
|
from diffusers.models.embeddings import get_1d_rotary_pos_embed |
|
|
|
class OmniGen2RotaryPosEmbed(nn.Module):
    """Assemble 3-axis rotary position embeddings for mixed text/image sequences.

    Every token gets a 3-component position id ``(axis0, axis1, axis2)``:

    * caption tokens use their running index on all three axes;
    * tokens of an image (reference or target) share one ``axis0`` value and
      use their patch row / patch column on ``axis1`` / ``axis2``.

    The ids are then used to look up per-axis frequency tables, and the
    per-axis results are concatenated along the feature dimension.

    Args:
        theta: Value forwarded as ``theta`` to ``get_1d_rotary_pos_embed`` by
            callers that build the tables via :meth:`get_freqs_cis`.
        axes_dim: Per-axis embedding dimension; its length also defines how
            many position-id components are looked up.
        axes_lens: Per-axis table length (number of positions per axis).
        patch_size: Divisor turning an image ``(H, W)`` into patch-token
            counts ``(H // patch_size, W // patch_size)``.
    """

    def __init__(self, theta: int,
                 axes_dim: Tuple[int, int, int],
                 axes_lens: Tuple[int, int, int] = (300, 512, 512),
                 patch_size: int = 2):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(axes_dim: Tuple[int, int, int],
                      axes_lens: Tuple[int, int, int],
                      theta: int) -> List[torch.Tensor]:
        """Build one rotary frequency table per axis.

        Args:
            axes_dim: Embedding dimension for each axis.
            axes_lens: Number of positions for each axis.
            theta: Rotary base passed through to ``get_1d_rotary_pos_embed``.

        Returns:
            A list with one entry per ``(dim, length)`` pair, each being
            whatever ``get_1d_rotary_pos_embed`` returns for that pair.
        """
        # float32 is selected whenever an MPS backend is available, float64
        # otherwise (presumably because MPS lacks float64 support -- confirm).
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        return [
            get_1d_rotary_pos_embed(dim, length, theta=theta, freqs_dtype=freqs_dtype)
            for dim, length in zip(axes_dim, axes_lens)
        ]

    @staticmethod
    def _grid_position_ids(h_tokens: int, w_tokens: int, device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return flattened row-major ``(row_ids, col_ids)`` for an ``h x w`` token grid."""
        rows = torch.arange(h_tokens, dtype=torch.int32, device=device)
        cols = torch.arange(w_tokens, dtype=torch.int32, device=device)
        row_ids = rows.unsqueeze(1).expand(h_tokens, w_tokens).flatten()
        col_ids = cols.unsqueeze(0).expand(h_tokens, w_tokens).flatten()
        return row_ids, col_ids

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        """Look up the per-axis frequency rows addressed by ``ids``.

        Args:
            freqs_cis: Sequence of 2-D tables ``(positions, features)``, one
                per axis, indexed in the same order as the last dim of ``ids``.
            ids: Integer tensor ``(batch, seq, n_axes)`` of position ids.

        Returns:
            Tensor ``(batch, seq, sum(features))`` on the original device of
            ``ids``: the per-axis lookups concatenated on the last dimension.
        """
        device = ids.device
        # On MPS the lookup is carried out on CPU and the result moved back.
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        batch_size = ids.shape[0]
        result = []
        for axis in range(len(self.axes_dim)):
            freqs = freqs_cis[axis].to(ids.device)
            # Stride-0 views instead of .repeat(): gather only reads from
            # them, so there is no need to materialise `batch` copies of the
            # table or `features` copies of the index.
            index = ids[:, :, axis : axis + 1].to(torch.int64).expand(-1, -1, freqs.shape[-1])
            table = freqs.unsqueeze(0).expand(batch_size, -1, -1)
            result.append(torch.gather(table, dim=1, index=index))
        return torch.cat(result, dim=-1).to(device)

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device
    ):
        """Compute rotary embeddings for caption + reference images + image.

        Args:
            freqs_cis: Per-axis frequency tables (see :meth:`get_freqs_cis`).
            attention_mask: ``(batch, encoder_seq_len)`` caption mask; its
                row sums are taken as the caption lengths.
            l_effective_ref_img_len: Per sample, a list with the token count
                of each reference image.
            l_effective_img_len: Per sample, the token count of the image.
            ref_img_sizes: Per sample, a list of ``(H, W)`` for each reference
                image, or ``None`` when the sample has none.
            img_sizes: Per sample, ``(H, W)`` of the image.
            device: Device on which position ids and outputs are allocated.

        Returns:
            Tuple ``(cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis,
            freqs_cis, l_effective_cap_len, seq_lengths)``. The first three
            are zero-padded per-segment slices of ``freqs_cis``; ``freqs_cis``
            covers the concatenated sequence padded to the longest sample
            (its padding rows carry the lookup for position id 0, not zeros).

        Raises:
            AssertionError: If a token count does not match
                ``(H // patch_size) * (W // patch_size)`` or the segment
                lengths of a sample do not add up to its sequence length.
        """
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        # Cast to int so that a floating-point mask still yields lengths that
        # are usable as slice bounds below.
        l_effective_cap_len = [int(n) for n in attention_mask.sum(dim=1).tolist()]

        # Total reference-image tokens per sample, computed once and reused.
        ref_totals = [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]
        seq_lengths = [
            cap_len + ref_total + img_len
            for cap_len, ref_total, img_len in zip(l_effective_cap_len, ref_totals, l_effective_img_len)
        ]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max(ref_totals)
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # Caption tokens: the same running index on all three axes
            # (the (L, 1) column broadcasts over the 3 components).
            position_ids[i, :cap_seq_len] = torch.arange(
                cap_seq_len, dtype=torch.int32, device=device
            ).unsqueeze(-1)

            pe_shift = cap_seq_len       # axis-0 id shared by the next image
            pe_shift_len = cap_seq_len   # write offset into the sequence

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len, (
                        f"reference image of size {(H, W)} yields "
                        f"{ref_H_tokens * ref_W_tokens} tokens, expected {ref_img_len}"
                    )

                    row_ids, col_ids = self._grid_position_ids(ref_H_tokens, ref_W_tokens, device)
                    segment = slice(pe_shift_len, pe_shift_len + ref_img_len)
                    position_ids[i, segment, 0] = pe_shift
                    position_ids[i, segment, 1] = row_ids
                    position_ids[i, segment, 2] = col_ids

                    # The next image starts past the larger grid side on axis 0.
                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i], (
                f"image of size {(H, W)} yields {H_tokens * W_tokens} tokens, "
                f"expected {l_effective_img_len[i]}"
            )

            row_ids, col_ids = self._grid_position_ids(H_tokens, W_tokens, device)

            assert pe_shift_len + l_effective_img_len[i] == seq_len, (
                f"segment lengths of sample {i} do not add up to its sequence length {seq_len}"
            )
            position_ids[i, pe_shift_len:seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len:seq_len, 1] = row_ids
            position_ids[i, pe_shift_len:seq_len, 2] = col_ids

        freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
        feat_dim = freqs_cis.shape[-1]

        cap_freqs_cis = torch.zeros(
            batch_size, encoder_seq_len, feat_dim, device=device, dtype=freqs_cis.dtype
        )
        ref_img_freqs_cis = torch.zeros(
            batch_size, max_ref_img_len, feat_dim, device=device, dtype=freqs_cis.dtype
        )
        img_freqs_cis = torch.zeros(
            batch_size, max_img_len, feat_dim, device=device, dtype=freqs_cis.dtype
        )

        # Split the concatenated sequence back into its three segments.
        for i, (cap_seq_len, ref_total, img_len) in enumerate(
            zip(l_effective_cap_len, ref_totals, l_effective_img_len)
        ):
            ref_end = cap_seq_len + ref_total
            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :ref_total] = freqs_cis[i, cap_seq_len:ref_end]
            img_freqs_cis[i, :img_len] = freqs_cis[i, ref_end:ref_end + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )