# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange
from transformers.utils import logging

from fla.modules import RotaryEmbedding
from fla.ops.nsa.parallel import parallel_nsa

if TYPE_CHECKING:
    from fla.models.utils import Cache

logger = logging.get_logger(__name__)


class NativeSparseAttention(nn.Module):
    """
    Native Sparse Attention (NSA) layer combining compressed, selected and
    sliding-window attention branches, mixed per head via sigmoid gates.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 64,
        num_kv_heads: Optional[int] = 4,
        head_dim: int = 64,
        qkv_bias: bool = False,
        block_size: Optional[int] = 64,
        block_counts: Optional[Union[torch.LongTensor, int]] = 16,
        window_size: Optional[int] = 512,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: Optional[int] = None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = head_dim
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias

        self.block_size = block_size
        self.block_counts = block_counts
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        # 3 gate logits per head: one for each of the compressed, selected and sliding-window branches
        self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, seq_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
        g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)

        cu_seqlens = kwargs.get('cu_seqlens', None)

        seqlen_offset, max_seqlen = 0, seq_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # subtract the number of padding tokens from the position offsets
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)

        if past_key_values is not None:
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=seq_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)

        o = parallel_nsa(
            q=q,
            k=k,
            v=v,
            g_cmp=g_cmp,
            g_slc=g_slc,
            g_swa=g_swa,
            block_size=self.block_size,
            block_counts=self.block_counts,
            window_size=self.window_size,
            cu_seqlens=cu_seqlens,
            head_first=False
        )
        o = o.reshape(batch_size, seq_len, -1)
        o = self.o_proj(o)

        # NSA never materializes full attention probabilities, so no weights are returned.
        attentions = None

        return o, attentions, past_key_values
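

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API): instantiate the layer and run a
    # forward pass on random hidden states. The shapes and hyperparameters below are
    # illustrative assumptions; parallel_nsa is a Triton kernel, so a CUDA device and
    # half-precision inputs are assumed here.
    if torch.cuda.is_available():
        device, dtype = 'cuda', torch.bfloat16
        layer = NativeSparseAttention(
            hidden_size=2048,
            num_heads=64,
            num_kv_heads=4,
            head_dim=64,
            block_size=64,
            block_counts=16,
            window_size=512,
            layer_idx=0
        ).to(device=device, dtype=dtype)
        hidden_states = torch.randn(2, 1024, 2048, device=device, dtype=dtype)
        o, attentions, past_key_values = layer(hidden_states)
        print(o.shape)  # expected: torch.Size([2, 1024, 2048])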