# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Tuple, Union

import torch
import torch.nn as nn
from einops import rearrange
from transformers.utils import logging

from fla.modules import RotaryEmbedding
from fla.ops.nsa.parallel import parallel_nsa

if TYPE_CHECKING:
    from fla.models.utils import Cache

logger = logging.get_logger(__name__)


class NativeSparseAttention(nn.Module):
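    """
    Native Sparse Attention (NSA) layer.

    A sketch of the intent, inferred from this module: each query head attends through
    three branches -- compressed (block-level) attention, selected-block attention, and
    sliding-window attention -- whose outputs are mixed by per-head sigmoid gates
    (`g_cmp`, `g_slc`, `g_swa`) produced by `g_proj`. The attention itself is computed
    by the `parallel_nsa` kernel from `fla.ops.nsa.parallel`.

    Args:
        hidden_size: model dimension of the input/output hidden states.
        num_heads: number of query heads.
        num_kv_heads: number of key/value heads (GQA); defaults to `num_heads` if None.
        head_dim: dimension of each attention head.
        qkv_bias: whether the q/k/v projections carry a bias term.
        block_size: size of the key/value blocks used for compression and selection.
        block_counts: number of selected blocks per query (int or per-token LongTensor).
        window_size: width of the sliding-window branch.
        rope_theta: base of the rotary position embedding.
        max_position_embeddings: optional lower bound used when sizing rotary tables.
        layer_idx: index of this layer inside the cache.
    """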
    def __init__(
        self,
        hidden_size: int = 2048,
        num_heads: int = 64,
        num_kv_heads: Optional[int] = 4,
        head_dim: int = 64,
        qkv_bias: bool = False,
        block_size: Optional[int] = 64,
        block_counts: Optional[Union[torch.LongTensor, int]] = 16,
        window_size: Optional[int] = 512,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: Optional[int] = None,
        layer_idx: Optional[int] = None
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        if num_kv_heads is None:
            self.num_kv_heads = self.num_heads
        else:
            self.num_kv_heads = num_kv_heads
        # grouped-query attention: every `num_kv_groups` query heads share one KV head
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.head_dim = head_dim
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.qkv_bias = qkv_bias

        self.block_size = block_size
        self.block_counts = block_counts
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias)
        # three gate logits per query head, one for each NSA branch
        self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )
        batch_size, seq_len, _ = hidden_states.size()

        q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim)
        # per-head sigmoid gates for the compressed, selected and sliding-window branches
        g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3)
        g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1)

        cu_seqlens = kwargs.get('cu_seqlens', None)

        seqlen_offset, max_seqlen = 0, seq_len
        if past_key_values is not None:
            seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
            max_seqlen = q.shape[1] + seqlen_offset

            if attention_mask is not None:
                # subtract the per-sequence padding length so offsets count only non-padding tokens
                seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]
                max_seqlen = q.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen = max(max_seqlen, self.max_position_embeddings)
        q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens)
        if past_key_values is not None:
            cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0
            # update the layer cache with flattened [..., num_kv_heads * head_dim] key/value states
            k_cached, v_cached = past_key_values.update(
                attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=seq_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            if cache_has_content:
                # decoding with an existing cache: use the full cached keys/values
                k, v = k_cached, v_cached
                k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim)
                v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
        o = parallel_nsa(
            q=q,
            k=k,
            v=v,
            g_cmp=g_cmp,
            g_slc=g_slc,
            g_swa=g_swa,
            block_size=self.block_size,
            block_counts=self.block_counts,
            window_size=self.window_size,
            cu_seqlens=cu_seqlens,
            head_first=False
        )
        o = o.reshape(batch_size, seq_len, -1)
        o = self.o_proj(o)

        # the NSA kernel does not return attention weights, so none are exposed even
        # when `output_attentions` is requested
        attentions = None

        return o, attentions, past_key_values
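

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the shapes, dtype and device below are
    # arbitrary assumptions, and the parallel_nsa kernel is expected to require a CUDA
    # GPU with Triton available.
    device, dtype = 'cuda', torch.bfloat16
    attn = NativeSparseAttention(
        hidden_size=2048,
        num_heads=64,
        num_kv_heads=4,
        head_dim=64,
        layer_idx=0
    ).to(device=device, dtype=dtype)
    x = torch.randn(2, 1024, 2048, device=device, dtype=dtype)
    o, _, _ = attn(x)
    print(o.shape)  # expected: torch.Size([2, 1024, 2048])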