"""
OmniGen2 Attention Processor Module

Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import math
from typing import Optional, Tuple, Dict, Any

import torch
import torch.nn.functional as F
from einops import repeat
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

from diffusers.models.attention_processor import Attention
from .embeddings import apply_rotary_emb


class OmniGen2AttnProcessorFlash2Varlen:
    """
    Processor implementing scaled dot-product attention with flash attention over variable-length sequences.

    This processor requires PyTorch 2.0 or later and implements:
    - Flash attention over variable-length (unpadded) sequences
    - Rotary position embeddings (RoPE)
    - Query-key normalization
    - Proportional attention scaling

    Raises:
        ImportError: If the installed PyTorch version is older than 2.0.
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
                "Please upgrade PyTorch to version 2.0 or later."
            )

    def _upad_input(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        num_heads: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
        """
        Unpad the input tensors for flash attention.

        Args:
            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            attention_mask: Attention mask tensor of shape (batch_size, seq_len)
            query_length: Length of the query sequence
            num_heads: Number of attention heads

        Returns:
            Tuple containing:
                - Unpadded query tensor
                - Unpadded key tensor
                - Unpadded value tensor
                - Query indices
                - Tuple of cumulative sequence lengths for query and key
                - Tuple of maximum sequence lengths for query and key
        """
        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
            """Helper function to get unpadding data from attention mask."""
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return indices, cu_seqlens, max_seqlen_in_batch
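
        # Derive unpadding metadata from the key/value attention mask: flattened indices of
        # real (non-padding) tokens, cumulative sequence lengths, and the longest sequence.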
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
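
        # Gather only the rows that correspond to real tokens, flattening (batch, seq) into
        # a single packed dimension as expected by flash-attn's varlen kernels.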
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
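
        # The query side distinguishes three cases: queries cover the full key/value length
        # (reuse the key metadata), a single query token per sample (trivial unpadding), or a
        # shorter query block (unpad against the tail of the attention mask).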
        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Compute attention using variable-length flash attention.

        Args:
            attn: Attention module whose projections and norms are used
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim); queries are projected from these
            encoder_hidden_states: Encoder hidden states tensor; keys and values are projected from these
            attention_mask: Attention mask of shape (batch_size, seq_len) marking valid (non-padding) tokens; used to unpad variable-length sequences
            image_rotary_emb: Optional rotary embeddings applied to queries and keys
            base_sequence_length: Optional base sequence length for proportional attention scaling

        Returns:
            torch.Tensor: Hidden states after attention and the output projection
        """
        batch_size, sequence_length, _ = hidden_states.shape
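
        # Linear projections: queries from hidden_states, keys/values from encoder_hidden_states.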
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype
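
        # Infer the number of key/value heads from the key projection width; this can be
        # smaller than attn.heads when grouped-query attention is used. Then split Q/K/V
        # into (batch, seq_len, heads, head_dim).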
        kv_heads = inner_dim // head_dim

        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)
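
        # Optional query-key normalization, applied only if the attention module defines
        # norm_q / norm_k.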
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
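
        # Apply rotary position embeddings (RoPE) to queries and keys, then cast back to the
        # original dtype in case the rotation was computed in a higher precision.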
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)
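
        # Proportional attention scaling: when base_sequence_length is given, the usual
        # 1/sqrt(head_dim) scale (attn.scale) is multiplied by
        # sqrt(log(sequence_length) / log(base_sequence_length)), a factor that grows slowly
        # with sequence length and is intended to keep attention sharpness roughly stable as
        # the token count exceeds the base length. For example (illustrative numbers), with
        # base 1024 and sequence length 4096 the multiplier is sqrt(12/10), about 1.095.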
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale
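
        # Drop padding tokens and build the cumulative-sequence-length (cu_seqlens) metadata
        # that flash_attn_varlen_func expects for packed, variable-length batches.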
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            cu_seq_lens,
            max_seq_lens,
        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)

        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
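
        # With grouped-query attention the key/value heads are tiled so that every query head
        # has a matching key/value head before the attention kernel is called.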
        if kv_heads < attn.heads:
            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
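
        # Run flash attention on the packed (unpadded) sequences: non-causal, no dropout,
        # using the softmax scale computed above.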
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=softmax_scale,
        )
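
        # Scatter the packed outputs back to (batch, seq_len, heads, head_dim), merge the head
        # dimensions, match the query dtype, and finish with the output projection and dropout.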
        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
        hidden_states = hidden_states.flatten(-2)
        hidden_states = hidden_states.type_as(query)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states