from typing import List, Tuple

import torch
import torch.nn as nn
from einops import repeat

from diffusers.models.embeddings import get_1d_rotary_pos_embed


class OmniGen2RotaryPosEmbed(nn.Module):
    """Build 3-axis rotary position embeddings for a [caption | ref images | image] sequence.

    Every token is assigned a position-id triple ``(axis0, axis1, axis2)``.  Each axis
    has its own 1D frequency table (see :meth:`get_freqs_cis`); the per-axis entries
    selected by the ids are concatenated along the last dimension.

    Args:
        theta: Rotary base forwarded to ``get_1d_rotary_pos_embed``.
        axes_dim: Per-axis embedding dimension forwarded to ``get_1d_rotary_pos_embed``.
        axes_lens: Per-axis table length; position ids on axis ``k`` must stay
            below ``axes_lens[k]`` or the table lookup fails.
        patch_size: Spatial patch size; image heights/widths are floor-divided by it
            to obtain the token grid.

    NOTE(review): ``theta`` and ``axes_lens`` are only stored here; ``forward``
    receives the frequency tables from the caller, presumably built via
    ``get_freqs_cis(self.axes_dim, self.axes_lens, self.theta)`` — confirm.
    """

    def __init__(
        self,
        theta: int,
        axes_dim: Tuple[int, int, int],
        axes_lens: Tuple[int, int, int] = (300, 512, 512),
        patch_size: int = 2,
    ):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        self.axes_lens = axes_lens
        self.patch_size = patch_size

    @staticmethod
    def get_freqs_cis(
        axes_dim: Tuple[int, int, int],
        axes_lens: Tuple[int, int, int],
        theta: int,
    ) -> List[torch.Tensor]:
        """Return one rotary frequency table per axis.

        Table ``k`` is ``get_1d_rotary_pos_embed(axes_dim[k], axes_lens[k], theta=theta)``;
        its first dimension indexes positions ``0 .. axes_lens[k] - 1``.
        """
        # Frequencies are computed in float32 when MPS is available (presumably
        # because MPS has no float64 support), float64 otherwise.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
        return [
            get_1d_rotary_pos_embed(dim, length, theta=theta, freqs_dtype=freqs_dtype)
            for dim, length in zip(axes_dim, axes_lens)
        ]

    def _get_freqs_cis(self, freqs_cis, ids: torch.Tensor) -> torch.Tensor:
        """Look up per-axis frequencies for integer position ids and concatenate them.

        Args:
            freqs_cis: Per-axis tables (one per entry of ``self.axes_dim``), each
                indexed by position along its first dimension.
            ids: Integer position ids whose last dimension has one entry per axis
                (``forward`` passes shape ``(batch, seq, 3)``).

        Returns:
            ``ids.shape[:-1] + (sum of per-axis table widths,)``, on ``ids``' device.
        """
        device = ids.device
        # Lookups run on CPU when ids live on MPS; the result is moved back at the end.
        if ids.device.type == "mps":
            ids = ids.to("cpu")

        per_axis = []
        for axis in range(len(self.axes_dim)):
            table = freqs_cis[axis].to(ids.device)
            # Advanced indexing: out[..., k] = table[ids[..., axis], k].  Equivalent to
            # gathering from a batch-replicated table, without materializing the copy.
            per_axis.append(table[ids[..., axis].long()])
        return torch.cat(per_axis, dim=-1).to(device)

    @staticmethod
    def _grid_position_ids(h_tokens: int, w_tokens: int, device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return flattened row-major (row_ids, col_ids), each of length ``h_tokens * w_tokens``."""
        row_ids = repeat(
            torch.arange(h_tokens, dtype=torch.int32, device=device), "h -> h w", w=w_tokens
        ).flatten()
        col_ids = repeat(
            torch.arange(w_tokens, dtype=torch.int32, device=device), "w -> h w", h=h_tokens
        ).flatten()
        return row_ids, col_ids

    def forward(
        self,
        freqs_cis,
        attention_mask,
        l_effective_ref_img_len,
        l_effective_img_len,
        ref_img_sizes,
        img_sizes,
        device,
    ):
        """Compute rotary embeddings for the packed caption / reference-image / image sequence.

        Position-id layout per sample ``i`` (padding slots keep id ``(0, 0, 0)``):

        * caption token ``t`` (``t < cap_len``): ``(t, t, t)``;
        * each reference image: axis 0 = a constant shift (starting at ``cap_len``),
          axes 1/2 = row/col in its token grid; the shift then advances by
          ``max(H_tokens, W_tokens)``;
        * target image: axis 0 = the final shift, axes 1/2 = row/col.

        Args:
            freqs_cis: Per-axis frequency tables, presumably from :meth:`get_freqs_cis`.
            attention_mask: Caption mask of shape ``(batch, encoder_seq_len)``; its row
                sums are the caption lengths.  Assumes valid caption tokens come first
                (right padding), since the first ``cap_len`` slots are used — TODO confirm.
            l_effective_ref_img_len: Per sample, a list of token counts, one per reference image.
            l_effective_img_len: Per sample, the target-image token count.
            ref_img_sizes: Per sample, a list of ``(H, W)`` per reference image, or ``None``.
            img_sizes: Per sample, the target image's ``(H, W)``.
            device: Device on which position ids are built.

        Returns:
            Tuple ``(cap_freqs_cis, ref_img_freqs_cis, img_freqs_cis, freqs_cis,
            l_effective_cap_len, seq_lengths)`` where the first three are zero-padded to
            ``encoder_seq_len``, the longest total reference-image length and the longest
            image length respectively, ``freqs_cis`` covers the full packed sequence padded
            to the longest sample, and the last two are per-sample Python lists.

        Raises:
            AssertionError: If a token count disagrees with its ``(H, W) // patch_size`` grid.
        """
        batch_size = len(attention_mask)
        p = self.patch_size

        encoder_seq_len = attention_mask.shape[1]
        l_effective_cap_len = attention_mask.sum(dim=1).tolist()
        # Total reference-image tokens per sample, computed once and reused below.
        l_ref_img_total = [sum(ref_img_len) for ref_img_len in l_effective_ref_img_len]

        seq_lengths = [
            cap_len + ref_total + img_len
            for cap_len, ref_total, img_len in zip(l_effective_cap_len, l_ref_img_total, l_effective_img_len)
        ]

        max_seq_len = max(seq_lengths)
        max_ref_img_len = max(l_ref_img_total)
        max_img_len = max(l_effective_img_len)

        position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            # Caption tokens share the same id on all three axes.
            position_ids[i, :cap_seq_len] = repeat(
                torch.arange(cap_seq_len, dtype=torch.int32, device=device), "l -> l 3"
            )

            pe_shift = cap_seq_len  # next axis-0 value
            pe_shift_len = cap_seq_len  # next free token slot in the sequence

            if ref_img_sizes[i] is not None:
                for ref_img_size, ref_img_len in zip(ref_img_sizes[i], l_effective_ref_img_len[i]):
                    H, W = ref_img_size
                    ref_H_tokens, ref_W_tokens = H // p, W // p
                    assert ref_H_tokens * ref_W_tokens == ref_img_len, (
                        f"reference image of size {(H, W)} gives "
                        f"{ref_H_tokens * ref_W_tokens} tokens, expected {ref_img_len}"
                    )

                    row_ids, col_ids = self._grid_position_ids(ref_H_tokens, ref_W_tokens, device)
                    span = slice(pe_shift_len, pe_shift_len + ref_img_len)
                    position_ids[i, span, 0] = pe_shift
                    position_ids[i, span, 1] = row_ids
                    position_ids[i, span, 2] = col_ids

                    # Advance axis 0 past this image's extent so images don't overlap.
                    pe_shift += max(ref_H_tokens, ref_W_tokens)
                    pe_shift_len += ref_img_len

            H, W = img_sizes[i]
            H_tokens, W_tokens = H // p, W // p
            assert H_tokens * W_tokens == l_effective_img_len[i], (
                f"image of size {(H, W)} gives {H_tokens * W_tokens} tokens, "
                f"expected {l_effective_img_len[i]}"
            )
            assert pe_shift_len + l_effective_img_len[i] == seq_len, (
                f"packed length {pe_shift_len + l_effective_img_len[i]} != seq_len {seq_len}"
            )

            row_ids, col_ids = self._grid_position_ids(H_tokens, W_tokens, device)
            position_ids[i, pe_shift_len:seq_len, 0] = pe_shift
            position_ids[i, pe_shift_len:seq_len, 1] = row_ids
            position_ids[i, pe_shift_len:seq_len, 2] = col_ids

        combined_freqs_cis = self._get_freqs_cis(freqs_cis, position_ids)
        width = combined_freqs_cis.shape[-1]
        dtype = combined_freqs_cis.dtype

        # Split the packed embeddings into zero-padded per-modality tensors.
        cap_freqs_cis = torch.zeros(batch_size, encoder_seq_len, width, device=device, dtype=dtype)
        ref_img_freqs_cis = torch.zeros(batch_size, max_ref_img_len, width, device=device, dtype=dtype)
        img_freqs_cis = torch.zeros(batch_size, max_img_len, width, device=device, dtype=dtype)

        for i, (cap_seq_len, ref_total, img_len) in enumerate(
            zip(l_effective_cap_len, l_ref_img_total, l_effective_img_len)
        ):
            img_start = cap_seq_len + ref_total
            cap_freqs_cis[i, :cap_seq_len] = combined_freqs_cis[i, :cap_seq_len]
            ref_img_freqs_cis[i, :ref_total] = combined_freqs_cis[i, cap_seq_len:img_start]
            img_freqs_cis[i, :img_len] = combined_freqs_cis[i, img_start:img_start + img_len]

        return (
            cap_freqs_cis,
            ref_img_freqs_cis,
            img_freqs_cis,
            combined_freqs_cis,
            l_effective_cap_len,
            seq_lengths,
        )