"""
OmniGen2 Attention Processor Module

Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from typing import Optional, Tuple, Dict, Any

import torch
import torch.nn.functional as F
from einops import repeat
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

from diffusers.models.attention_processor import Attention

from .embeddings import apply_rotary_emb


class OmniGen2AttnProcessorFlash2Varlen:
    """
    Processor for implementing scaled dot-product attention with flash attention
    and variable length sequences.

    Padding positions (zeros in the attention mask) are stripped out before the
    flash-attn varlen kernel is called and the output is re-padded afterwards.

    This processor implements:
    - Flash attention with variable length sequences
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling

    Args:
        None

    Raises:
        ImportError: If PyTorch version is less than 2.0
    """

    def __init__(self) -> None:
        """Initialize the attention processor.

        Raises:
            ImportError: If ``torch.nn.functional`` has no
                ``scaled_dot_product_attention`` (i.e. PyTorch < 2.0).
        """
        # NOTE(review): attention is computed by flash_attn below, not by
        # F.scaled_dot_product_attention; this check only acts as a
        # "PyTorch >= 2.0" version gate.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
                "Please upgrade PyTorch to version 2.0 or later."
            )

    def _upad_input(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        num_heads: int,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[int, int],
    ]:
        """
        Unpad the input tensors for flash attention.

        Args:
            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            attention_mask: Attention mask tensor of shape (batch_size, seq_len);
                non-zero entries mark key positions that are kept.
            query_length: Length of the query sequence
            num_heads: Number of attention heads

        Returns:
            Tuple containing:
                - Unpadded query tensor
                - Unpadded key tensor
                - Unpadded value tensor
                - Query indices (into the flattened batch_size * query_length axis)
                - Tuple of cumulative sequence lengths for query and key
                - Tuple of maximum sequence lengths for query and key
        """

        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
            """Return (flat indices of kept tokens, int32 cu_seqlens with a leading 0, max seqlen)."""
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return indices, cu_seqlens, max_seqlen_in_batch

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # Unpad key and value layers: flatten (batch, seq) and keep only masked-in rows.
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        # Handle different query length cases
        if query_length == kv_seq_len:
            # Queries share the key padding pattern, so reuse the key indices.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # One query token per batch element: cu_seqlens is simply 0..batch_size.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # NOTE(review): the trailing slice assumes the queries correspond to
            # the last `query_length` key positions (left padding / incremental
            # decoding); confirm this holds for cross-attention callers.
            attention_mask = attention_mask[:, -query_length:]
            # Newer flash-attn releases return an extra trailing value from
            # unpad_input; star-unpacking stays compatible with old and new APIs.
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with flash attention.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor (source of keys/values)
            attention_mask: Optional key mask of shape (batch_size, kv_seq_len);
                non-zero entries are attended to. ``None`` means no padding,
                i.e. every key position is attended to.
            image_rotary_emb: Optional rotary embeddings for image tokens
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value Pair
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Get key-value heads (fewer than attn.heads when K/V projections are narrower)
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation: (batch, seq, heads, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply Rotary Position Embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        # Restore the projection dtype (normalization / RoPE may change it).
        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # A missing mask means "no padding": keep every key position instead of
        # failing inside the unpadding helper.
        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, key.shape[1], dtype=torch.bool, device=key.device
            )

        # Unpad input for flash attention
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            cu_seq_lens,
            max_seq_lens,
        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        # Handle different number of heads: expand K/V heads to match the query heads.
        if kv_heads < attn.heads:
            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)

        # Apply flash attention
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=softmax_scale,
        )

        # Pad output back to (batch, seq, heads, head_dim), then merge heads.
        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
        hidden_states = hidden_states.flatten(-2)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states