""" | |
OmniGen2 Attention Processor Module | |
Copyright 2025 BAAI, The OmniGen2 Team and The HuggingFace Team. All rights reserved. | |
Licensed under the Apache License, Version 2.0 (the "License"); | |
you may not use this file except in compliance with the License. | |
You may obtain a copy of the License at | |
http://www.apache.org/licenses/LICENSE-2.0 | |
Unless required by applicable law or agreed to in writing, software | |
distributed under the License is distributed on an "AS IS" BASIS, | |
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
See the License for the specific language governing permissions and | |
limitations under the License. | |
""" | |
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention
from einops import repeat
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

from .embeddings import apply_rotary_emb

class OmniGen2AttnProcessorFlash2Varlen:
    """
    Processor implementing scaled dot-product attention with flash attention
    over variable-length (unpadded) sequences.

    This processor requires PyTorch 2.0 and implements:
    - Flash attention with variable-length sequences
    - Rotary position embeddings (RoPE)
    - Query-Key normalization
    - Proportional attention scaling

    Raises:
        ImportError: If PyTorch version is less than 2.0.
    """

    def __init__(self) -> None:
        """Initialize the attention processor."""
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "OmniGen2AttnProcessorFlash2Varlen requires PyTorch 2.0. "
                "Please upgrade PyTorch to version 2.0 or later."
            )

    def _upad_input(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        num_heads: int,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        Tuple[torch.Tensor, torch.Tensor],
        Tuple[int, int],
    ]:
        """
        Unpad the input tensors for flash attention.

        Args:
            query_layer: Query tensor of shape (batch_size, seq_len, num_heads, head_dim)
            key_layer: Key tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            value_layer: Value tensor of shape (batch_size, seq_len, num_kv_heads, head_dim)
            attention_mask: Attention mask tensor of shape (batch_size, seq_len)
            query_length: Length of the query sequence
            num_heads: Number of attention heads

        Returns:
            Tuple containing:
                - Unpadded query tensor
                - Unpadded key tensor
                - Unpadded value tensor
                - Query indices
                - Tuple of cumulative sequence lengths for query and key
                - Tuple of maximum sequence lengths for query and key
        """
        def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
            """Helper function to get unpadding data from an attention mask."""
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return indices, cu_seqlens, max_seqlen_in_batch
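        # Illustrative example: for attention_mask [[1, 1, 1, 0], [1, 1, 0, 0]]
        # the helper returns indices [0, 1, 2, 4, 5], cu_seqlens [0, 3, 5], and
        # max_seqlen_in_batch 3.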
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        # Unpad key and value layers
        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        # Handle different query length cases
        if query_length == kv_seq_len:
            # Queries and keys share the same padding layout
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token (decode-style) queries: one query per batch element
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # Queries cover only the trailing part of the key sequence
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        base_sequence_length: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Process attention computation with flash attention.

        Args:
            attn: Attention module
            hidden_states: Hidden states tensor of shape (batch_size, seq_len, hidden_dim)
            encoder_hidden_states: Encoder hidden states tensor; for self-attention,
                pass the hidden states themselves
            attention_mask: Attention mask of shape (batch_size, seq_len); this
                processor needs it to unpad the inputs
            image_rotary_emb: Optional rotary embeddings for image tokens
            base_sequence_length: Optional base sequence length for proportional attention

        Returns:
            torch.Tensor: Processed hidden states after attention computation
        """
        batch_size, sequence_length, _ = hidden_states.shape

        # Get Query-Key-Value projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query_dim = query.shape[-1]
        inner_dim = key.shape[-1]
        head_dim = query_dim // attn.heads
        dtype = query.dtype

        # Infer the number of key-value heads (may be fewer than query heads)
        kv_heads = inner_dim // head_dim

        # Reshape tensors for attention computation
        query = query.view(batch_size, -1, attn.heads, head_dim)
        key = key.view(batch_size, -1, kv_heads, head_dim)
        value = value.view(batch_size, -1, kv_heads, head_dim)

        # Apply Query-Key normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply rotary position embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
            key = apply_rotary_emb(key, image_rotary_emb, use_real=False)

        query, key = query.to(dtype), key.to(dtype)

        # Calculate attention scale
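        # Note (illustrative): the multiplier
        # sqrt(log(sequence_length) / log(base_sequence_length)) equals 1.0 when
        # the runtime sequence length matches the base (e.g. training) length
        # and grows for longer sequences, sharpening the softmax as the token
        # count increases.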
        if base_sequence_length is not None:
            softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
        else:
            softmax_scale = attn.scale

        # Unpad input for flash attention
        (
            query_states,
            key_states,
            value_states,
            indices_q,
            cu_seq_lens,
            max_seq_lens,
        ) = self._upad_input(query, key, value, attention_mask, sequence_length, attn.heads)
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        # Handle different number of heads
        if kv_heads < attn.heads:
            key_states = repeat(key_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
            value_states = repeat(value_states, "l h c -> l (h k) c", k=attn.heads // kv_heads)
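        # (Grouped-query attention: each key/value head is repeated
        # attn.heads // kv_heads times so K and V match the query head count.)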
        # Apply flash attention
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=0.0,
            causal=False,
            softmax_scale=softmax_scale,
        )
        # Pad output back to (batch_size, seq_len, heads, head_dim), then merge heads
        hidden_states = pad_input(attn_output_unpad, indices_q, batch_size, sequence_length)
        hidden_states = hidden_states.flatten(-2)
        hidden_states = hidden_states.type_as(query)

        # Apply output projection
        hidden_states = attn.to_out[0](hidden_states)  # linear projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        return hidden_states
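
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module). This is illustrative only and
# makes several assumptions: a CUDA device, flash-attn installed, a half
# precision dtype, and made-up tensor sizes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device, dtype = "cuda", torch.bfloat16

    # A self-attention block routed through the varlen flash processor.
    attn = Attention(
        query_dim=64,
        heads=4,
        dim_head=16,
        processor=OmniGen2AttnProcessorFlash2Varlen(),
    ).to(device, dtype)

    # Two sequences padded to length 6: the first uses all 6 tokens, the second 4.
    hidden_states = torch.randn(2, 6, 64, device=device, dtype=dtype)
    attention_mask = torch.tensor(
        [[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 0, 0]],
        device=device,
        dtype=torch.bool,
    )

    out = attn(
        hidden_states,
        encoder_hidden_states=hidden_states,  # self-attention
        attention_mask=attention_mask,
    )
    print(out.shape)  # torch.Size([2, 6, 64])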