diff --git a/fla/__pycache__/__init__.cpython-311.pyc b/fla/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edf3e8ef03bbd0c2b20c6539020edc4e5b777aed Binary files /dev/null and b/fla/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/__pycache__/utils.cpython-311.pyc b/fla/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38f69b2c23a4655038e334dd3eb9ae5f1eb5bb71 Binary files /dev/null and b/fla/__pycache__/utils.cpython-311.pyc differ diff --git a/fla/layers/__init__.py b/fla/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16b64d5b84beda1d06fc7334f6fea2e5aba6a7fc --- /dev/null +++ b/fla/layers/__init__.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from .abc import ABCAttention +from .attn import Attention +from .based import BasedLinearAttention +from .bitattn import BitAttention +from .delta_net import DeltaNet +from .forgetting_attn import ForgettingAttention +from .gated_deltanet import GatedDeltaNet +from .gated_deltaproduct import GatedDeltaProduct +from .gla import GatedLinearAttention +from .gsa import GatedSlotAttention +from .hgrn import HGRNAttention +from .hgrn2 import HGRN2Attention +from .lightnet import LightNetAttention +from .linear_attn import LinearAttention +from .multiscale_retention import MultiScaleRetention +from .nsa import NativeSparseAttention +from .rebased import ReBasedLinearAttention +from .rwkv6 import RWKV6Attention +from .rwkv7 import RWKV7Attention + +__all__ = [ + 'ABCAttention', + 'Attention', + 'BasedLinearAttention', + 'BitAttention', + 'DeltaNet', + 'ForgettingAttention', + 'GatedDeltaNet', + 'GatedDeltaProduct', + 'GatedLinearAttention', + 'GatedSlotAttention', + 'HGRNAttention', + 'HGRN2Attention', + 'LightNetAttention', + 'LinearAttention', + 'MultiScaleRetention', + 'NativeSparseAttention', + 'ReBasedLinearAttention', + 'RWKV6Attention', + 'RWKV7Attention', +] diff --git a/fla/layers/__pycache__/__init__.cpython-311.pyc b/fla/layers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13eb7470d4543d357bdc769fdb1245944e65971f Binary files /dev/null and b/fla/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/abc.cpython-311.pyc b/fla/layers/__pycache__/abc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..144c4877405a622166ee70263c64b63076991e05 Binary files /dev/null and b/fla/layers/__pycache__/abc.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/attn.cpython-311.pyc b/fla/layers/__pycache__/attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4901bb55096a003c02e6e16fcf0355884fe73bb9 Binary files /dev/null and b/fla/layers/__pycache__/attn.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/based.cpython-311.pyc b/fla/layers/__pycache__/based.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa3cce60fc50071e4b44ed93e92b3f1b22943167 Binary files /dev/null and b/fla/layers/__pycache__/based.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/bitattn.cpython-311.pyc b/fla/layers/__pycache__/bitattn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02f16a7638f6e17ea01729c00f176d1a10002d1f Binary files /dev/null and b/fla/layers/__pycache__/bitattn.cpython-311.pyc differ diff --git 
a/fla/layers/__pycache__/delta_net.cpython-311.pyc b/fla/layers/__pycache__/delta_net.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..247b4e085a9ac987822ad1a2551ef213eb8bc8a2 Binary files /dev/null and b/fla/layers/__pycache__/delta_net.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/forgetting_attn.cpython-311.pyc b/fla/layers/__pycache__/forgetting_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6044b0adf6a90f3a18442660a78a67c1943f96d1 Binary files /dev/null and b/fla/layers/__pycache__/forgetting_attn.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/gated_deltanet.cpython-311.pyc b/fla/layers/__pycache__/gated_deltanet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c555ae7c9a2652e054c362416b9e84913dd3ecf2 Binary files /dev/null and b/fla/layers/__pycache__/gated_deltanet.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/gated_deltaproduct.cpython-311.pyc b/fla/layers/__pycache__/gated_deltaproduct.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cbb172349104c594b4c05157ae01cb65cf21520 Binary files /dev/null and b/fla/layers/__pycache__/gated_deltaproduct.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/gla.cpython-311.pyc b/fla/layers/__pycache__/gla.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83c9514b002c05a35df8362ba447644674f59a0c Binary files /dev/null and b/fla/layers/__pycache__/gla.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/gsa.cpython-311.pyc b/fla/layers/__pycache__/gsa.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc7eb7cba28adfb6afe69be4c92978972041677c Binary files /dev/null and b/fla/layers/__pycache__/gsa.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/hgrn.cpython-311.pyc b/fla/layers/__pycache__/hgrn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f63e42e813120f49fe083476eff28cb2d29d6ba Binary files /dev/null and b/fla/layers/__pycache__/hgrn.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/hgrn2.cpython-311.pyc b/fla/layers/__pycache__/hgrn2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9758111e5dc8e2865a43866c869550cee7a8a509 Binary files /dev/null and b/fla/layers/__pycache__/hgrn2.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/lightnet.cpython-311.pyc b/fla/layers/__pycache__/lightnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8ef136d66a2e3396ba7a16307b2afcae94fd773 Binary files /dev/null and b/fla/layers/__pycache__/lightnet.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/linear_attn.cpython-311.pyc b/fla/layers/__pycache__/linear_attn.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2eaaeca17d2a47f0af299344575aadbdb672ce71 Binary files /dev/null and b/fla/layers/__pycache__/linear_attn.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/multiscale_retention.cpython-311.pyc b/fla/layers/__pycache__/multiscale_retention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95bf49f76d7242c95eef6751a6c9361425a1ad5d Binary files /dev/null and b/fla/layers/__pycache__/multiscale_retention.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/nsa.cpython-311.pyc b/fla/layers/__pycache__/nsa.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..508df44b895dee49b512320f2df5b7b8d32045b2 Binary files /dev/null and b/fla/layers/__pycache__/nsa.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/rebased.cpython-311.pyc b/fla/layers/__pycache__/rebased.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36209e572b8bf87f12c2ed3b40302a3d99e648f6 Binary files /dev/null and b/fla/layers/__pycache__/rebased.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/rwkv6.cpython-311.pyc b/fla/layers/__pycache__/rwkv6.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52bc8a51449887d766ddb52989bcd2ddfceff3a0 Binary files /dev/null and b/fla/layers/__pycache__/rwkv6.cpython-311.pyc differ diff --git a/fla/layers/__pycache__/rwkv7.cpython-311.pyc b/fla/layers/__pycache__/rwkv7.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6a6be1e673ef47ad516f902a8768d165e5b624c Binary files /dev/null and b/fla/layers/__pycache__/rwkv7.cpython-311.pyc differ diff --git a/fla/layers/abc.py b/fla/layers/abc.py new file mode 100644 index 0000000000000000000000000000000000000000..3afc3ebd015f4414edf26474f3f0a98b0d18e0e4 --- /dev/null +++ b/fla/layers/abc.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules import FusedRMSNormGated, RMSNorm, RotaryEmbedding, ShortConvolution +from fla.modules.activations import swiglu, swish +from fla.ops.abc.chunk import chunk_abc + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +class ABCAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 1024, + expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + num_slots: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_low_rank_dim: int = 16, + gate_logit_normalizer: int = 16, + use_rope: bool = True, + use_input_gate: bool = False, + use_output_gate: bool = True, + use_norm: bool = True, + clamp_min: Optional[float] = -32, + clamp_max: Optional[float] = 32, + layer_idx: Optional[int] = None, + **kwargs + ) -> ABCAttention: + super().__init__() + + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.key_dim = int(self.hidden_size * self.expand_k) + self.value_dim = int(self.hidden_size * self.expand_v) + self.head_k_dim = self.key_dim // self.num_heads + self.head_v_dim = self.value_dim // self.num_heads + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.gate_low_rank_dim = gate_low_rank_dim + self.gate_logit_normalizer = gate_logit_normalizer + + self.use_rope = use_rope + self.use_input_gate = use_input_gate + self.use_output_gate = use_output_gate + self.use_norm = use_norm + + if num_slots is None: + num_slots = self.head_k_dim + self.num_slots = num_slots + + self.norm_eps = norm_eps + + self.clamp_min = clamp_min + self.clamp_max = clamp_max + self.layer_idx = layer_idx + + if layer_idx is None: + warnings.warn( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. 
Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) + + if use_output_gate: + self.g_proj = nn.Linear(self.hidden_size, self.value_dim, bias=False) + self.s_proj = nn.Linear(self.hidden_size, self.num_heads * self.num_slots, bias=False) + self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation='silu') + + if self.use_norm: + if self.use_output_gate: + self.g_norm = FusedRMSNormGated( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + else: + self.g_norm = RMSNorm( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + + if self.use_rope: + self.rotary = RotaryEmbedding(self.head_k_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if cu_seqlens is not None: + raise NotImplementedError("Training with cu_seqlens is not supported yet for ABCAttention") + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if self.use_input_gate: + q, k, v = map(lambda x: swish(x), (q, k, v)) + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + + q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k)) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_v_dim) + if self.use_rope: + seqlen_offset = 0 + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset) + + s = rearrange(self.s_proj(hidden_states), '... (h m) -> ... h m', m=self.num_slots) + s = s.clamp_(self.clamp_min, self.clamp_max) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + o, recurrent_state = chunk_abc( + q=q, + k=k, + v=v, + s=s, + initial_state=recurrent_state, + output_final_state=use_cache, + head_first=False + ) + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + if self.use_norm and not self.use_output_gate: + o = self.g_norm(o) + elif self.use_output_gate: + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) + o = self.g_norm(o, g) if self.use_norm else swiglu(g, o) + o = rearrange(o, '... h d -> ... (h d)') + o = self.o_proj(o) + + return o, None, past_key_values + + def state_size(self, seq_len: int = 2048): + return 2 * self.num_slots * self.hidden_size diff --git a/fla/layers/attn.py b/fla/layers/attn.py new file mode 100644 index 0000000000000000000000000000000000000000..723cd5450a90912d31b0d98c18e2c9b1113cf427 --- /dev/null +++ b/fla/layers/attn.py @@ -0,0 +1,222 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from transformers.utils import logging + +from fla.modules import RMSNorm, RotaryEmbedding +from fla.ops import parallel_attn, parallel_rectified_attn, parallel_softpick_attn, naive_attn, naive_rectified_attn, naive_softpick_attn + +if TYPE_CHECKING: + from fla.models.utils import Cache + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +except ImportError: + warnings.warn( + "Flash Attention is not installed. 
Please install it via `pip install flash-attn --no-build-isolation`", + category=ImportWarning + ) + flash_attn_func = None + +logger = logging.get_logger(__name__) + + +class Attention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: Optional[int] = None, + layer_idx: int = None, + attn_impl: str = "flash_attn", + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + self.attn_impl = attn_impl + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + if qk_norm: + self.q_norm = RMSNorm(self.head_dim) + self.k_norm = RMSNorm(self.head_dim) + + self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.size() + + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + + q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_dim) + + if self.qk_norm: + q, k = self.q_norm(q), self.k_norm(k) + + # equivalent to cu_seqlens in `flash_attn` + cu_seqlens = kwargs.get('cu_seqlens', None) + + seqlen_offset, max_seqlen = 0, q_len + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1] + max_seqlen = q.shape[1] + max(seqlen_offset) + + if self.max_position_embeddings is not None: + max_seqlen = max(max_seqlen, self.max_position_embeddings) + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens) + + if past_key_values is not None: + cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0 + k_cached, v_cached = past_key_values.update( + attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)), + layer_idx=self.layer_idx, + offset=q_len, + cache_kwargs=dict(window_size=self.window_size) + )['attn_state'] + if cache_has_content: + k, v = k_cached, v_cached + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + + if flash_attn_func is None: + raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first") + + # Contains at least one padding token in the sequence + if self.attn_impl == "flash_attn": + if attention_mask is not None: + q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_q, max_seqlen_k = max_seq_lens + o = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ) + o = pad_input(o, indices_q, batch_size, q_len) + elif cu_seqlens is not None: + o = flash_attn_varlen_func( + q.squeeze(0), k.squeeze(0), v.squeeze(0), + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ).unsqueeze(0) + else: + o = flash_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ) + elif self.attn_impl == "parallel_attn": + o = parallel_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens) + elif self.attn_impl == "parallel_rectified_attn": + o = parallel_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens) + elif self.attn_impl == "parallel_softpick_attn": + o = parallel_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens) + elif self.attn_impl == "naive_attn": + o, attentions = naive_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens) + elif self.attn_impl == "naive_rectified_attn": + o, attentions = naive_rectified_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens) + elif self.attn_impl == "naive_softpick_attn": + o, attentions = naive_softpick_attn(q, k, v, scale=self.head_dim**-0.5, cu_seqlens=cu_seqlens) + else: + raise ValueError(f"Unknown attention implementation: {self.attn_impl}") + + o = o.reshape(batch_size, q_len, -1) + o = self.o_proj(o) + + if not output_attentions or "parallel" in self.attn_impl or "flash" in 
self.attn_impl: + attentions = None + + return o, attentions, past_key_values + + def _upad_input(self, q, k, v, attention_mask, q_len): + batch_size, seq_len, num_key_value_heads, head_dim = k.shape + cache_mask = attention_mask[:, -seq_len:] + seqlens = cache_mask.sum(-1, dtype=torch.int32) + indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten() + max_seqlen_k = seqlens.max().item() + cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + + k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) + v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) + if q_len == seq_len: + q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_q = max_seqlen_k + indices_q = indices_k + elif q_len == 1: + max_seqlen_q = 1 + # There is a memcpy here, that is very bad. + cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) + indices_q = cu_seqlens_q[:-1] + q = q.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -q_len:] + q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask) + + return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) diff --git a/fla/layers/based.py b/fla/layers/based.py new file mode 100644 index 0000000000000000000000000000000000000000..64a12364d246d7eeedcf40ac39ea21a6a4312971 --- /dev/null +++ b/fla/layers/based.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +""" +Linear attention in Based. +https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py +""" + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules.feature_map import TaylorFeatureMap +from fla.ops.based import parallel_based +from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn + + +class BasedLinearAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + feature_dim: int = 16, + num_key_value_heads: int = 12, + num_heads: int = 12, + feature_name: str = "taylor_exp", + eps: float = 1e-12, + causal: bool = True, + mode: str = "parallel", + ): + super().__init__() + + self.hidden_size = hidden_size + self.mode = mode + self.feature_name = feature_name + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.hidden_size // self.num_key_value_heads + assert self.hidden_size % self.head_dim == 0 + self.causal = causal + + self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.dropout = nn.Identity() + self.feature_map = TaylorFeatureMap(feature_dim) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... 
h d", d=self.head_dim), [q, k, v]) + if mode == "fused_chunk": + q, k = self.feature_map(q), self.feature_map(k) + o, _ = fused_chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False) + elif mode == 'chunk': + q, k = self.feature_map(q), self.feature_map(k) + o, _ = chunk_linear_attn(q, k, v, normalize=True, scale=1, head_first=False) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_based(q, k, v, scale=1, use_norm=True, head_first=False) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + o = self.dropout(o) + return o + + # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 + + def forward_reference(self, hidden_states: torch.Tensor, filters: torch.Tensor = None, *args, **kwargs): + """ + x (torch.Tensor): tensor of shape (b, d, t) + y (torch.Tensor): tensor of shape (b, d, t) + """ + # hidden_states = hidden_states.transpose(1, 2) + b, t, _ = hidden_states.size() + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + + q = q.view(b, t, self.num_heads, self.feature_dim).transpose(1, 2) + k = k.view(b, t, self.num_key_value_heads, self.feature_dim).transpose(1, 2) + v = v.view(b, t, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h t d -> b t (h d)') + y = self.o_proj(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) diff --git a/fla/layers/bitattn.py b/fla/layers/bitattn.py new file mode 100644 index 0000000000000000000000000000000000000000..f797b164963161f7b292b844e06e16527ea0817a --- /dev/null +++ b/fla/layers/bitattn.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from transformers.utils import logging + +from fla.modules import RotaryEmbedding +from fla.modules.fused_bitlinear import FusedBitLinear + +if TYPE_CHECKING: + from fla.models.utils import Cache + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +except ImportError: + warnings.warn( + "Flash Attention is not installed. 
Please install it via `pip install flash-attn --no-build-isolation`", + category=ImportWarning + ) + flash_attn_func = None + +logger = logging.get_logger(__name__) + + +class BitAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + window_size: Optional[int] = None, + rope_theta: Optional[float] = 10000., + max_position_embeddings: Optional[int] = None, + norm_eps: float = 1e-5, + layer_idx: int = None + ): + super().__init__() + + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.hidden_size = hidden_size + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.kv_dim = self.num_kv_heads * self.head_dim + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + + self.q_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False) + self.k_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False) + self.v_proj = FusedBitLinear(self.hidden_size, self.kv_dim, bias=False) + self.o_proj = FusedBitLinear(self.hidden_size, self.hidden_size, bias=False) + + self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, q_len, _ = hidden_states.size() + + q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + + # equivalent to cu_seqlens in `flash_attn` + cu_seqlens = kwargs.get('cu_seqlens', None) + + seqlen_offset, max_seqlen = 0, q_len + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1] + max_seqlen = q.shape[1] + max(seqlen_offset) + + if self.max_position_embeddings is not None: + max_seqlen = max(max_seqlen, self.max_position_embeddings) + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens) + + if past_key_values is not None: + cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0 + k_cached, v_cached = past_key_values.update( + attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)), + layer_idx=self.layer_idx, + offset=q_len, + cache_kwargs=dict(window_size=self.window_size) + )['attn_state'] + if cache_has_content: + k, v = k_cached, v_cached + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_dim) + + if flash_attn_func is None: + raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first") + + # Contains at least one padding token in the sequence + if attention_mask is not None: + q, k, v, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q, k, v, attention_mask, q_len) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_q, max_seqlen_k = max_seq_lens + o = flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ) + o = pad_input(o, indices_q, batch_size, q_len) + elif cu_seqlens is not None: + o = flash_attn_varlen_func( + q.squeeze(0), k.squeeze(0), v.squeeze(0), + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ).unsqueeze(0) + else: + o = flash_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) + ) + o = o.reshape(batch_size, q_len, -1) + o = self.o_proj(o) + + if not output_attentions: + attentions = None + + return o, attentions, past_key_values + + def _upad_input(self, q, k, v, attention_mask, q_len): + batch_size, seq_len, num_key_value_heads, head_dim = k.shape + cache_mask = attention_mask[:, -seq_len:] + seqlens = cache_mask.sum(-1, dtype=torch.int32) + indices_k = torch.nonzero(cache_mask.flatten(), as_tuple=False).flatten() + max_seqlen_k = seqlens.max().item() + cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) + + k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) + v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) + if q_len == seq_len: + q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_q = max_seqlen_k + indices_q = indices_k + elif q_len == 1: + max_seqlen_q = 1 + # There is a memcpy here, that is very bad. + cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) + indices_q = cu_seqlens_q[:-1] + q = q.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -q_len:] + q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask) + + return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) diff --git a/fla/layers/delta_net.py b/fla/layers/delta_net.py new file mode 100644 index 0000000000000000000000000000000000000000..5190902eec586f0bb296995163eee0a98fe2cbfe --- /dev/null +++ b/fla/layers/delta_net.py @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F + +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.ops.delta_rule import chunk_delta_rule, fused_recurrent_delta_rule + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +class DeltaNet(nn.Module): + r""" + The layer implementaion for [Parallelizing Linear Transformers with the Delta Rule over Sequence Length](https://arxiv.org/abs/2406.06484). # noqa: + DeltaNet was originally proposed in [Linear Transformers Are Secretly Fast Weight Programmers](https://arxiv.org/abs/2102.11174). # noqa + + Args: + mode (str, Optional): + Which DeltaNet kernel to use. + Currently available: `chunk`, `fused_recurrent`, and `fused_chunk`. + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 1.0. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 1.0. + num_heads (int, Optional): + The number of heads. Default: 4. + use_beta (bool, Optional): + Whether to use beta. Default: `True`. + use_gate (bool, Optional): + Whether to use output gate. Default: `False`. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `True`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + allow_neg_eigval (bool, Optional): + Allow negative eigenvalues. Default: `False`. If set to `True`, the beta will be multiplied by 2. + See reference: [Unlocking State-Tracking in Linear RNNs Through Negative Eigenvalues](https://arxiv.org/abs/2411.12537) + layer_idx (int, Optional): + The index of the layer. Default: None. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + qk_activation (str, Optional): + The activation function for the query and key. Default: `silu`. + qk_norm (str, Optional): + The normalization method for the query and key. Default: `l2`. 
+ """ + + def __init__( + self, + mode: str = 'chunk', + d_model: int = None, + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 1.0, + num_heads: int = 4, + use_beta: bool = True, + use_gate: bool = False, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + allow_neg_eigval: bool = False, + layer_idx: int = None, + qk_activation: str = 'silu', + qk_norm: str = 'l2', + norm_eps: float = 1e-5, + **kwargs + ) -> DeltaNet: + super().__init__() + + self.mode = mode + self.qk_activation = qk_activation + self.qk_norm = qk_norm + + assert self.qk_activation in ['silu', 'relu', 'elu', 'identity'] + assert self.qk_norm in ['l2', 'sum'] + + if d_model is not None: + hidden_size = d_model + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.allow_neg_eigval = allow_neg_eigval + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.layer_idx = layer_idx + + self.silu = nn.SiLU() + if mode == 'fused_chunk': + raise NotImplementedError("fused_chunk_delta_rule is now deprecated. Please use `chunk_delta_rule` instead.") + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + self.use_beta = use_beta + if self.use_beta: + self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' if qk_activation == 'silu' else None + ) + self.k_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' if qk_activation == 'silu' else None + ) + self.v_conv1d = ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + activation='silu' + ) + else: + raise UserWarning( + "ShortConvolution is crucial to the performance. " + "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." + ) + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." 
+ ) + + # change to inference mode. + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + if self.qk_activation == 'silu': + q, k = self.silu(q), self.silu(k) + v = self.silu(self.v_proj(hidden_states)) + + q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k)) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_v_dim) + if self.qk_activation != 'silu': + if self.qk_activation == 'relu': + q, k = q.relu(), k.relu() + elif self.qk_activation == 'elu': + q, k = elu_p1(q), elu_p1(k) + elif self.qk_activation == 'identity': + pass + else: + raise NotImplementedError + + if self.qk_norm == 'sum': + q = sum_norm(q).to(q) + k = sum_norm(k).to(k) + + if self.use_beta: + beta = self.b_proj(hidden_states).sigmoid() + else: + beta = q.new_ones(q.shape[0], q.shape[1], q.shape[2]) + + if self.allow_neg_eigval: + beta = beta * 2. + + # dealing with padding + if attention_mask is not None: + beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_delta_rule( + q=q, + k=k, + v=v, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False, + use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False + ) + elif mode == 'chunk': + o, recurrent_state = chunk_delta_rule( + q=q, + k=k, + v=v, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False, + use_qk_l2norm_in_kernel=True if self.qk_norm == 'l2' else False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + if self.use_gate: + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... 
h d', d=self.head_v_dim) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + + return o, None, past_key_values diff --git a/fla/layers/forgetting_attn.py b/fla/layers/forgetting_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2ddc30ef531beedefd0f65836f4148771bcb7d --- /dev/null +++ b/fla/layers/forgetting_attn.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange +from transformers.utils import logging + +from fla.modules import GroupNorm +from fla.ops.forgetting_attn.parallel import parallel_forgetting_attn + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +logger = logging.get_logger(__name__) + + +class ForgettingAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 32, + num_kv_heads: Optional[int] = None, + qkv_bias: bool = False, + qk_norm: bool = False, + window_size: Optional[int] = None, + use_output_gate: bool = False, + layer_idx: int = None + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.head_dim = self.hidden_size // self.num_heads + self.kv_dim = self.num_kv_heads * self.head_dim + self.qkv_bias = qkv_bias + self.qk_norm = qk_norm + + self.window_size = window_size + self.use_output_gate = use_output_gate + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=self.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.f_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True) + + if use_output_gate: + self.g_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + if qk_norm: + self.q_norm = GroupNorm( + num_groups=self.num_heads, + hidden_size=self.hidden_size, + is_rms_norm=True, + ) + self.k_norm = GroupNorm( + num_groups=self.num_kv_heads, + hidden_size=self.kv_dim, + is_rms_norm=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + cu_seqlens = kwargs.get('cu_seqlens', None) + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + f = F.logsigmoid(self.f_proj(hidden_states).float()) + if self.qk_norm: + q, k = self.q_norm(q), self.k_norm(k) + + q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_dim) + + o = parallel_forgetting_attn(q, k, v, f, cu_seqlens=cu_seqlens) + o = rearrange(o, '... h d -> ... (h d)') + if self.use_output_gate: + o = self.g_proj(hidden_states).sigmoid() * o + o = self.o_proj(o) + + return o, None, past_key_values diff --git a/fla/layers/gated_deltanet.py b/fla/layers/gated_deltanet.py new file mode 100644 index 0000000000000000000000000000000000000000..01bc3385cb7d40e35aee2fc693387d1bf1496c22 --- /dev/null +++ b/fla/layers/gated_deltanet.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange +from torch.nn import functional as F + +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +@torch.compile +def elu_p1(x): + return (F.elu(x, 1., False) + 1.).to(x) + + +@torch.compile +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +class GatedDeltaNet(nn.Module): + """ + The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa + + Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. + + Parameter alloation when use_gate=True: + - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each + - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each + - Others are ignorably small. + - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size + NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. + + Parameter allocation when use_gate=False: + - 1 * hidden_size * hidden_size for the q_proj and k_proj each + - 2 * hidden_size * hidden_size for the v_proj and o_proj each + - Others are ignorably small. + - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size + + Args: + hidden_size (int, Optional): + The hidden size of the input. Default: 2048. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 2.0. + head_dim (int, Optional): + The dimension of each head. Default: 256. + num_heads (int, Optional): + The number of heads. Default: 4. + mode (str, Optional): + Which Gated DeltaNet kernel to use. + Currently available: `chunk` and `fused_recurrent`. + Default: `chunk`. + use_beta (bool, Optional): + Whether to use beta. Default: `True`. + use_gate (bool, Optional): + Whether to use output gate. Default: `True`. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `True`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + layer_idx (int, Optional): + The index of the layer. Default: None. + norm_eps (float, Optional): + The epsilon value for the normalization layer. Default: 1e-5. 
+ """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2, + head_dim: int = 256, + num_heads: int = 6, + mode: str = 'chunk', + use_gate: bool = True, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int = None, + norm_eps: float = 1e-5, + **kwargs + ) -> GatedDeltaNet: + super().__init__() + + self.mode = mode + + self.hidden_size = hidden_size + self.expand_v = expand_v + + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.head_dim = head_dim + self.num_heads = num_heads + + self.key_dim = int(self.num_heads * self.head_dim) + self.value_dim = int(self.key_dim * self.expand_v) + self.head_k_dim = head_dim + self.head_v_dim = int(head_dim * self.expand_v) + self.layer_idx = layer_idx + + # Consistency check: Ensure expand_v produces integer values + if not math.isclose(self.key_dim * expand_v, self.value_dim, rel_tol=1e-5): + raise ValueError( + f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. " + f"Resulting value_dim would be {self.key_dim * expand_v}, which is invalid for nn.Linear." + ) + if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5): + raise ValueError( + f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. " + f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated." + ) + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + + A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + # hard coded for now + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + # Just to be explicit. Without this we already don't put wd on dt_bias because of the check + # name.endswith("bias") in param_grouping.py + self.dt_bias._no_weight_decay = True + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' + ) + self.k_conv1d = ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation='silu' + ) + self.v_conv1d = ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + activation='silu' + ) + else: + raise UserWarning( + "ShortConvolution is crucial to the performance. " + "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." 
+ ) + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + if self.training: + assert mode == 'chunk', "Only chunk mode is supported in training." + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = F.silu(self.q_proj(hidden_states)) + k = F.silu(self.k_proj(hidden_states)) + v = F.silu(self.v_proj(hidden_states)) + + q, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (q, k)) + v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim) + beta = self.b_proj(hidden_states).sigmoid() + g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) + + # dealing with padding + if attention_mask is not None: + beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) + g = g.mul(attention_mask[:, -g.shape[-2]:, None]) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + o, recurrent_state = chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False, + use_qk_l2norm_in_kernel=True + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False, + use_qk_l2norm_in_kernel=True + ) + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + if self.use_gate: + g = 
rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=self.head_v_dim) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, 'b t h d -> b t (h d)') + o = self.o_proj(o) + + return o, None, past_key_values diff --git a/fla/layers/gated_deltaproduct.py b/fla/layers/gated_deltaproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..77cad9a1143614ebb55233030f63f4c8b80c5a53 --- /dev/null +++ b/fla/layers/gated_deltaproduct.py @@ -0,0 +1,351 @@ +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution +from fla.ops.delta_rule import chunk_delta_rule +from fla.ops.gated_delta_rule import chunk_gated_delta_rule + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +def elu_p1(x): + return (F.elu(x, 1.0, False) + 1.0).to(x) + + +def sum_norm(x): + return (x / x.sum(-1, keepdim=True)).to(x) + + +def interleave_multiple_sequences(*sequences): + """ + Interleave multiple sequences together. + For example, with sequences [A1, A2], [B1, B2], [C1, C2], + returns [A1, B1, C1, A2, B2, C2] + """ + if isinstance(sequences[0], (list, tuple)): + sequences = sequences[0] + + if len(sequences) == 1: + return sequences[0] + + # All sequences should have the same shape + assert all(s.shape == sequences[0].shape for s in sequences) + + # Get the original shape + batch_size, seq_len, *rest = sequences[0].shape + + # Stack sequences along a new dimension + stacked = torch.stack(sequences, dim=2) + + # Reshape to interleave + reshaped = stacked.view(batch_size, seq_len * len(sequences), *rest) + + return reshaped + + +class GatedDeltaProduct(nn.Module): + """ + Generalized version of GatedDoubleDeltaNet that supports arbitrary number of householder transformations. + """ + + def __init__( + self, + hidden_size: int = 2048, + expand_v: float = 2, + head_dim: int = 256, + num_heads: int = 6, + num_householder: int = 2, # New parameter for number of householder transformations + mode: str = "chunk", + use_gate: bool = True, + use_forget_gate: bool = True, # when true Gated DeltaProduct, when false DeltaProduct + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + layer_idx: int | None = None, + norm_eps: float = 1e-5, + allow_neg_eigval: bool = False, # when true (Gated) DeltaProduct [-1, 1], when false (Gated) DeltaProduct [0, 1] + **kwargs, + ) -> None: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_v = expand_v + self.use_gate = use_gate + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.head_dim = head_dim + self.num_heads = num_heads + self.num_householder = num_householder + self.allow_neg_eigval = allow_neg_eigval + self.use_forget_gate = use_forget_gate + self.key_dim = self.num_heads * self.head_dim + self.value_dim = int(self.key_dim * self.expand_v) + self.head_qk_dim = head_dim + self.head_v_dim = int(head_dim * self.expand_v) + self.layer_idx = layer_idx + self.silu = nn.SiLU() + assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`." 
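+ # Shape sketch (assuming the defaults above: hidden_size=2048, head_dim=256,
+ # num_heads=6, expand_v=2): key_dim = 6 * 256 = 1536 and value_dim = int(1536 * 2) = 3072,
+ # so each of the `num_householder` per-step projections below maps 2048 -> 1536 (k),
+ # 2048 -> 3072 (v) and 2048 -> 6 (beta), while the single shared q_proj maps 2048 -> 1536.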
+ # Create multiple projection layers for each householder transformation + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + + self.k_projs = nn.ModuleList( + [ + nn.Linear(hidden_size, self.key_dim, bias=False) + for _ in range(num_householder) + ] + ) + self.v_projs = nn.ModuleList( + [ + nn.Linear(hidden_size, self.value_dim, bias=False) + for _ in range(num_householder) + ] + ) + self.b_projs = nn.ModuleList( + [ + nn.Linear(hidden_size, self.num_heads, bias=False) + for _ in range(num_householder) + ] + ) + if use_short_conv: + self.q_conv1ds = nn.ModuleList( + [ + ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation="silu", + ) + for _ in range(num_householder) + ] + ) + self.k_conv1ds = nn.ModuleList( + [ + ShortConvolution( + hidden_size=self.key_dim, + kernel_size=conv_size, + activation="silu", + ) + for _ in range(num_householder) + ] + ) + self.v_conv1ds = nn.ModuleList( + [ + ShortConvolution( + hidden_size=self.value_dim, + kernel_size=conv_size, + activation="silu", + ) + for _ in range(num_householder) + ] + ) + + if self.use_forget_gate: + self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False) + A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) + A_log = torch.log(A) + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # Initialize dt parameters + dt_min = 0.001 + dt_max = 0.1 + dt_init_floor = 1e-4 + dt = torch.exp( + torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + + math.log(dt_min) + ) + dt = torch.clamp(dt, min=dt_init_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) + self.dt_bias = nn.Parameter(inv_dt) + self.dt_bias._no_weight_decay = True + + if use_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) + else: + self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + self.k_id = torch.nn.Identity() + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding)." + ) + + mode = ( + "chunk" # 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + ) + if self.training: + assert mode == "chunk", "Only chunk mode is supported in training." 
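+ # Interleaving sketch (illustrative, shown for num_householder=2): per token t the keys,
+ # values and betas computed below are laid out as [x1_t, x2_t, x1_{t+1}, x2_{t+1}, ...],
+ # queries are interleaved with zeros ([0, q_t, 0, q_{t+1}, ...]) so only the last
+ # sub-position of each token reads the state, the forget gate is applied only at the
+ # first sub-position, and the output later keeps o[:, num_householder - 1::num_householder].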
+ + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + # Process each householder transformation + ks, vs, betas = [], [], [] + conv_states = [] + + for i in range(self.num_householder): + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"][ + i + ] + conv_mask = ( + attention_mask[:, -hidden_states.shape[1]:] + if attention_mask is not None + else None + ) + + k, conv_state_k = self.k_conv1ds[i]( + x=self.k_projs[i](hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + ) + v, conv_state_v = self.v_conv1ds[i]( + x=self.v_projs[i](hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + ) + conv_states.append((conv_state_q, conv_state_k, conv_state_v)) + else: + k = self.silu(self.k_projs[i](hidden_states)) + v = self.silu(self.v_projs[i](hidden_states)) + + ks.append(k) + vs.append(v) + + beta = self.b_projs[i]( + hidden_states + ).sigmoid() # bs, sequence_length, num_heads + if attention_mask is not None: + beta = beta.mul(attention_mask[:, -hidden_states.shape[1]:, None]) + if self.allow_neg_eigval: + beta = beta * 2 + betas.append(beta) + + if self.use_short_conv: + q, conv_state_q = self.q_conv1ds[0]( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + ) + else: + q = self.silu(self.q_proj(hidden_states)) + q = interleave_multiple_sequences( + [torch.zeros_like(q)] * (self.num_householder - 1) + [q] + ) + # Interleave all sequences + k = interleave_multiple_sequences(ks) + v = interleave_multiple_sequences(vs) + beta = interleave_multiple_sequences(betas) + + q, k, v = ( + rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in (q, k, v) + ) + + recurrent_state = ( + last_state["recurrent_state"] if last_state is not None else None + ) + offsets = kwargs.get("offsets") + + if mode == "chunk": + if self.use_forget_gate: + g = -self.A_log.float().exp() * F.softplus( + self.a_proj(hidden_states).float() + self.dt_bias + ) + if attention_mask is not None: + g = g.mul(attention_mask[:, -g.shape[-2]:, None]) + + # Interleave g with zeros for non-first transformations + g = interleave_multiple_sequences( + [g] + [torch.zeros_like(g)] * (self.num_householder - 1) + ) + + o, recurrent_state = chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=offsets, + head_first=False, + use_qk_l2norm_in_kernel=True + ) + else: + o, recurrent_state = chunk_delta_rule( + q=q, + k=k, + v=v, + beta=beta, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=offsets, + head_first=False, + use_qk_l2norm_in_kernel=True + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + # Take every nth element for n householder transformations + o = o[:, self.num_householder - 1:: self.num_householder, :] + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=conv_states if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[2], + ) + + if self.use_gate: + g = rearrange( + self.g_proj(hidden_states), + "... (h d) -> ... 
h d", + h=self.num_heads, + ) + o = self.o_norm(o, g) + else: + o = self.o_norm(o) + o = rearrange(o, "b t h d -> b t (h d)") + o = self.o_proj(o) + + return o, None, past_key_values diff --git a/fla/layers/gsa.py b/fla/layers/gsa.py new file mode 100644 index 0000000000000000000000000000000000000000..dc22b1d392d9dc117ae91b92167c3db32e746732 --- /dev/null +++ b/fla/layers/gsa.py @@ -0,0 +1,227 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.modules import RMSNorm, ShortConvolution +from fla.modules.feature_map import ReLUFeatureMap, SwishFeatureMap, T2RFeatureMap +from fla.modules.layernorm import rms_norm_linear +from fla.ops.gsa import chunk_gsa, fused_recurrent_gsa + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class GatedSlotAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 1., + expand_v: float = 1., + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + num_slots: Optional[int] = None, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_logit_normalizer: int = 8, + feature_map: str = 'swish', + use_output_gate: bool = False, + use_norm: bool = True, + layer_idx: Optional[int] = None, + scale: Optional[float] = 1., + **kwargs + ) -> GatedSlotAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.head_k_dim = self.key_dim // self.num_heads + self.head_v_dim = self.value_dim // self.num_heads + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.gate_logit_normalizer = gate_logit_normalizer + + self.use_output_gate = use_output_gate + self.use_norm = use_norm + self.scale = scale + + if num_slots is None: + num_slots = self.head_k_dim + self.num_slots = num_slots + + self.layer_idx = layer_idx + + if layer_idx is None: + warnings.warn( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) + + self.register_module('feature_map', None) + if feature_map == 'swish': + self.feature_map = SwishFeatureMap() + elif feature_map == 'relu': + self.feature_map = ReLUFeatureMap() + elif feature_map == 't2r': + self.feature_map = T2RFeatureMap(self.head_k_dim, self.head_k_dim) + else: + raise NotImplementedError(f"Feature map `{feature_map}` is not supported now.") + + self.q_proj = nn.Linear(self.hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.value_dim_per_group, bias=False) + self.f_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.num_slots, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.g_norm = RMSNorm(self.hidden_size, elementwise_affine, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + f = self.f_proj(hidden_states) + + q = rearrange(q, 'b t (h d) -> b t h d', d=self.head_k_dim) + k = rearrange(k, 'b t (h d) -> b t h d', d=self.head_k_dim) + v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim) + f = rearrange(f, 'b t (h m) -> b t h m', m=self.num_slots) + + if self.feature_map is not None: + q, k = map(lambda x: self.feature_map(x), (q, k)) + v = F.silu(v) + + f = F.logsigmoid(f) / self.gate_logit_normalizer + s = (1 - f.exp()).to(f.dtype) + # dealing with left-padding + if attention_mask is not None: + 
s = s.mul_(attention_mask[:, -s.shape[1]:, None, None]) + v = v.mul_(attention_mask[:, -v.shape[1]:, None, None]) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gsa( + q=q, + k=k, + v=v, + s=s, + g=f, + initial_state=recurrent_state, + output_final_state=use_cache, + scale=self.scale, + cu_seqlens=cu_seqlens, + head_first=False + ) + elif mode == 'chunk': + o, recurrent_state = chunk_gsa( + q=q, + k=k, + v=v, + s=s, + g=f, + initial_state=recurrent_state, + output_final_state=use_cache, + scale=self.scale, + cu_seqlens=cu_seqlens, + head_first=False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + o = rearrange(o, 'b t h d -> b t (h d)') + o = rms_norm_linear(F.silu(o), self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias) + return o, None, past_key_values + + def state_size(self, *args, **kwargs) -> int: + return 2 * self.num_slots * self.hidden_size diff --git a/fla/layers/hgrn.py b/fla/layers/hgrn.py new file mode 100644 index 0000000000000000000000000000000000000000..6dfa0f946f9b47ec2315605554e2362acef6cd56 --- /dev/null +++ b/fla/layers/hgrn.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# "Hierarchically Gated Recurrent Neural Network for Sequence Modeling" [https://arxiv.org/abs/2311.04823] + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fla.modules import FusedRMSNormGated, ShortConvolution +from fla.modules.activations import swiglu +from fla.ops.hgrn import chunk_hgrn, fused_recurrent_hgrn + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class HGRNAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_ratio: Optional[int] = 1, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None + ) -> HGRNAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_ratio = expand_ratio + self.input_dim = int(hidden_size * expand_ratio) + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
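+ # Conceptual recurrence (hedged sketch of what chunk_hgrn/fused_recurrent_hgrn compute
+ # in `forward`): h_t = sigmoid(f_t) * h_{t-1} + x_t, where x_t is the swish-gated input
+ # branch modulated by (1 - sigmoid(f_t)) (or by 1 - g_t once the per-layer lower bound
+ # applies), so each of the `input_dim` channels carries an independent scalar state.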
+ + self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + self.g_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None) + self.f_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None) + self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None) + + self.g_norm = FusedRMSNormGated( + hidden_size=self.input_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = None, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if not self.training and hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_i, conv_state_f = None, None + if last_state is not None: + conv_state_i, conv_state_f = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + i, conv_state_i = self.i_conv1d( + x=self.i_proj(hidden_states), + mask=conv_mask, + cache=conv_state_i, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + f, conv_state_f = self.f_conv1d( + x=self.f_proj(hidden_states), + mask=conv_mask, + cache=conv_state_f, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + i = self.i_proj(hidden_states) + f = self.f_proj(hidden_states) + + # the lower bound for the first layer is zero + if lower_bound is None or self.layer_idx == 0: + i, f = swiglu(i, 1 - f.sigmoid()), F.logsigmoid(f) + else: + g = lower_bound + (1 - lower_bound) * f.sigmoid() + i, f = swiglu(i, 1 - g), g.log() + + # dealing with left-padding + if attention_mask is not None: + i = i.mul_(attention_mask[:, -i.shape[-2]:, None]) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + if cu_seqlens is not None: + raise NotImplementedError("Chunk mode does not support variable-length sequences.") + o, recurrent_state = chunk_hgrn( + x=i, + g=f, + initial_state=recurrent_state, + output_final_state=use_cache, + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_hgrn( + x=i, + g=f, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_i, conv_state_f) if self.use_short_conv else None, + layer_idx=self.layer_idx, + 
offset=i.shape[2] + ) + + o = self.g_norm(o, self.g_proj(hidden_states)) + o = self.o_proj(o) + + return o, None, past_key_values + + def state_size(self, **kwargs) -> int: + state_size = self.hidden_size + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/fla/layers/hgrn2.py b/fla/layers/hgrn2.py new file mode 100644 index 0000000000000000000000000000000000000000..13fb04253790113c770faa93101b890eb63e4a19 --- /dev/null +++ b/fla/layers/hgrn2.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# "HGRN2: Gated Linear RNNs with State Expansion"[https://arxiv.org/abs/2404.07904] + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.modules import RMSNorm, ShortConvolution +from fla.modules.activations import swish +from fla.modules.layernorm import rms_norm_linear +from fla.ops.gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class HGRN2Attention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None + ) -> HGRN2Attention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + + if expand_ratio is None and num_heads is not None: + expand_ratio = hidden_size // num_heads + elif expand_ratio is not None and num_heads is None: + num_heads = hidden_size // expand_ratio + elif expand_ratio is None and num_heads is None: + raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.") + self.num_heads = num_heads + self.expand_ratio = expand_ratio + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.forget_dim = int(self.num_heads * self.expand_ratio) + self.input_dim = hidden_size + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent', 'fused_chunk'], f"Not suppoerted mode `{mode}`." 
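+ # Dimension sketch with the defaults above (hidden_size=1024, expand_ratio=128):
+ # num_heads = 1024 // 128 = 8, forget_dim = 8 * 128 = 1024, head_f_dim = 128 and
+ # head_i_dim = 1024 // 8 = 128. State expansion is realized by reusing the GLA kernels
+ # in `forward` with k = 1 - g and the log of the (optionally lower-bounded) forget gate g
+ # passed as the per-key decay term.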
+ assert self.forget_dim % num_heads == 0, f"forget dim must be divisible by num_heads of {num_heads}" + assert self.input_dim % num_heads == 0, f"input dim must be divisible by num_heads of {num_heads}" + + self.head_f_dim = self.expand_ratio + self.head_i_dim = self.hidden_size // num_heads + + self.q_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) + self.f_proj = nn.Linear(hidden_size, self.forget_dim, bias=False) + self.i_proj = nn.Linear(hidden_size, self.input_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None) + self.f_conv1d = ShortConvolution(self.forget_dim, conv_size, activation=None) + self.i_conv1d = ShortConvolution(self.input_dim, conv_size, activation=None) + + self.g_norm = RMSNorm(hidden_size=self.hidden_size, elementwise_affine=elementwise_affine, eps=norm_eps) + self.o_proj = nn.Linear(self.input_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + lower_bound: Optional[torch.Tensor] = None, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_f, conv_state_i = None, None, None + if last_state is not None: + conv_state_q, conv_state_f, conv_state_i = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + f, conv_state_f = self.f_conv1d( + x=self.f_proj(hidden_states), + mask=conv_mask, + cache=conv_state_f, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + i, conv_state_i = self.i_conv1d( + x=self.i_proj(hidden_states), + mask=conv_mask, + cache=conv_state_i, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + f = self.f_proj(hidden_states) + i = self.i_proj(hidden_states) + + # dealing with left-padding + if attention_mask is not None: + i = i.mul_(attention_mask[:, -i.shape[-2]:, None]) + + q = swish(q) + + # improve precision + f = f.float() + + # the lower bound for the first layer is zero + if lower_bound is None or self.layer_idx == 0: + k, g = 1 - f.sigmoid(), F.logsigmoid(f) + else: + g = lower_bound + (1 - lower_bound) * f.sigmoid() + k, g = 1 - g, g.log() + + q, k, g = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k.to(i), g)) + i = rearrange(i, '... (h d) -> ... 
h d', d=self.head_i_dim) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gla( + q=q, + k=k, + v=i, + gk=g, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + elif mode == 'fused_chunk': + o, recurrent_state = fused_chunk_gla( + q=q, + k=k, + v=i, + g=g, + initial_state=recurrent_state, + output_final_state=use_cache, + head_first=False + ) + elif mode == 'chunk': + o, recurrent_state = chunk_gla( + q=q, + k=k, + v=i, + g=g, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_f, conv_state_i) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + o = rearrange(o, '... h d -> ... (h d)') + o = rms_norm_linear(o, self.g_norm.weight, self.g_norm.bias, self.o_proj.weight, self.o_proj.bias) + return o, None, past_key_values + + def state_size(self, **kwargs) -> int: + state_size = self.forget_dim * self.head_i_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/fla/layers/lightnet.py b/fla/layers/lightnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f3c01a88938a302502f12a54abeb159cf4db58 --- /dev/null +++ b/fla/layers/lightnet.py @@ -0,0 +1,210 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ["You Only Scan Once: Efficient Multi-dimension Sequential Modeling with LightNet"](https://arxiv.org/abs/2405.21022) + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from fla.modules import FusedRMSNormGated, ShortConvolution +from fla.modules.fused_norm_gate import rms_norm_swish_gate_linear +from fla.ops.gla import chunk_gla, fused_recurrent_gla + +if TYPE_CHECKING: + from transformers.processing_utils import Unpack + + from fla.models.utils import Cache + + +class LightNetAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + num_heads: Optional[int] = None, + expand_ratio: Optional[int] = 128, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + gate_low_rank_dim: int = 128, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None + ) -> LightNetAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + + if expand_ratio is None and num_heads is not None: + expand_ratio = hidden_size // num_heads + elif expand_ratio is not None and num_heads is None: + num_heads = hidden_size // expand_ratio + elif expand_ratio is None and num_heads is None: + raise RuntimeError("One of `expand_ratio` or `num_heads` should be provided.") + self.num_heads = num_heads + self.expand_ratio = expand_ratio + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.key_dim = int(self.num_heads * self.expand_ratio) + self.value_dim = hidden_size + self.gate_low_rank_dim = gate_low_rank_dim + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_chunk'], 
f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_f_dim = self.expand_ratio + self.head_i_dim = self.hidden_size // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None) + self.k_conv1d = ShortConvolution(self.key_dim, conv_size, activation=None) + self.v_conv1d = ShortConvolution(self.value_dim, conv_size, activation=None) + + self.g_proj = nn.Sequential( + nn.Linear(hidden_size, gate_low_rank_dim, bias=False), + nn.Linear(gate_low_rank_dim, hidden_size, bias=False) + ) + self.g_norm = FusedRMSNormGated( + hidden_size=hidden_size, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs: Unpack[Dict] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + + q = F.silu(q) + q, k = map(lambda x: rearrange(x, '... (h d) -> ... h d', d=self.head_f_dim), (q, k)) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_i_dim) + # TODO: this 2 steps took huge amount of time, which should be optimized + z = k.float().logcumsumexp(1) + + if cu_seqlens is not None: + raise NotImplementedError("LightNet does not support variable-length sequences for now.") + k, g = torch.exp(k - z).to(k.dtype), (torch.cat((z[:, :1], z[:, :-1]), 1) - z).to(k.dtype) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_gla( + q=q, + k=k, + v=v, + gk=g, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + elif mode == 'chunk': + o, recurrent_state = chunk_gla( + q=q, + k=k, + v=v, + g=g, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + o = rms_norm_swish_gate_linear( + rearrange(o, 'b t h d -> b t (h d)'), + self.g_proj(hidden_states), + self.g_norm.weight, + self.g_norm.bias, + self.o_proj.weight, + self.o_proj.bias + ) + return o, None, past_key_values + + def state_size(self, **kwargs) -> int: + state_size = self.key_dim * self.head_i_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/fla/layers/linear_attn.py b/fla/layers/linear_attn.py new file mode 100644 index 0000000000000000000000000000000000000000..54c570b7b5759f0a5d7e9653e87b1ab07b3546da --- /dev/null +++ b/fla/layers/linear_attn.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from fla.modules import RMSNorm +from fla.modules.feature_map import DPFPFeatureMap, HadamardFeatureMap, HedgehogFeatureMap, T2RFeatureMap +from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn + + +class LinearAttention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: str = 1024, + expand_k: int = 1.0, + expand_v: int = 1.0, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: str = 'elementwise_product', + tie_feature_map_qk: bool = False, + output_norm: str = 'rmsnorm', + norm_q: bool = False, + norm_k: bool = False, + do_feature_map_norm: bool = False, + elementwise_affine: bool = True, + norm_eps: float = 1e-5, + **kwargs + ): + super().__init__() + + self.hidden_size = hidden_size + self.mode = mode + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + + assert mode in ['chunk', 'fused_chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." 
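+ # Conceptual sketch: with a feature map phi, the kernels used in `forward` compute, per head,
+ #   o_t = phi(q_t) @ S_t,   S_t = sum_{s <= t} phi(k_s)^T v_s,
+ # optionally normalized by phi(q_t) @ sum_{s <= t} phi(k_s)^T when `do_feature_map_norm`
+ # is set. With the defaults (hidden_size=1024, num_heads=8, expand_k=expand_v=1.0),
+ # head_k_dim = head_v_dim = 1024 // 8 = 128.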
+ assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + self.do_feature_map_norm = do_feature_map_norm + + if feature_map == 'hedgehog': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim) + else: + self.feature_map_q = HedgehogFeatureMap(head_dim=self.head_k_dim) + self.feature_map_k = HedgehogFeatureMap(head_dim=self.head_k_dim) + + elif feature_map == 't2r': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim) + else: + self.feature_map_q = T2RFeatureMap(head_dim=self.head_k_dim) + self.feature_map_k = T2RFeatureMap(head_dim=self.head_k_dim) + + elif feature_map == 'elementwise_product': + if tie_feature_map_qk: + self.feature_map_q = self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim) + else: + self.feature_map_q = HadamardFeatureMap(head_dim=self.head_k_dim) + self.feature_map_k = HadamardFeatureMap(head_dim=self.head_k_dim) + + elif feature_map == 'dpfp': + self.feature_map_q = DPFPFeatureMap(head_dim=self.head_k_dim) + self.feature_map_k = DPFPFeatureMap(head_dim=self.head_k_dim) + + elif feature_map == 'elu': + def elu(x): + return F.elu(x) + 1 + self.feature_map_q = elu + self.feature_map_k = elu + + elif feature_map == 'relu': + self.feature_map_q = nn.ReLU() + self.feature_map_k = nn.ReLU() + + elif feature_map == 'identity': + self.feature_map_q = nn.Identity() + self.feature_map_k = nn.Identity() + else: + raise NotImplementedError(f"Not supported feature map `{feature_map}`.") + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + + if output_norm == 'rmsnorm': + self.norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps) + elif output_norm == 'identity': + self.norm = nn.Identity() + else: + raise NotImplementedError(f"Not supported output norm `{output_norm}`.") + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.norm_q = norm_q + self.norm_k = norm_k + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs + ) -> torch.Tensor: + mode = self.mode + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim) + if self.num_kv_groups > 1: + k = repeat(k, '... (h d) -> ... (h g) d', d=self.head_k_dim, g=self.num_kv_groups) + v = repeat(v, '... (h d) -> ... (h g) d', d=self.head_v_dim, g=self.num_kv_groups) + else: + k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim) + v = rearrange(v, '... (h d) -> ... 
h d', d=self.head_v_dim) + + q = self.feature_map_q(q) + k = self.feature_map_k(k) + + if self.norm_q: + q = q / (q.sum(-1, True) + 1e-4) + if self.norm_k: + k = k / (k.sum(-1, True) + 1e-4) + + if mode == 'chunk': + o, final_state = chunk_linear_attn( + q=q, + k=k, + v=v, + normalize=self.do_feature_map_norm, + head_first=False + ) + elif mode == 'fused_chunk': + o, final_state = fused_chunk_linear_attn( + q=q, + k=k, + v=v, + normalize=self.do_feature_map_norm, + ) + elif mode == 'fused_recurrent': + o, final_state = fused_recurrent_linear_attn( + q=q, + k=k, + v=v, + normalize=self.do_feature_map_norm, + ) + else: + raise NotImplementedError + o = self.norm(o) + o = rearrange(o, '... h d -> ... (h d)') + o = self.o_proj(o) + return o diff --git a/fla/layers/multiscale_retention.py b/fla/layers/multiscale_retention.py new file mode 100644 index 0000000000000000000000000000000000000000..e5364b356b245221b78c32c41eba1b4ca1cc5622 --- /dev/null +++ b/fla/layers/multiscale_retention.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from transformers.activations import ACT2FN + +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.modules.rotary import RotaryEmbedding +from fla.ops.retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +class MultiScaleRetention(nn.Module): + r""" + The layer implementaion for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa + + Args: + mode (str, Optional): + Which Retention kernel to use. + Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`. + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 1.0. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 2.0. + num_heads (int, Optional): + The number of heads. Default: 8. + num_kv_heads (int, Optional): + The number of key/value heads, used for MQA. Default: None. + feature_map (str, Optional): + Feature map function applied to queries/keys. Default: None. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `False`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + use_output_gate (bool, Optional): + Whether to use output gate. Default: `True`. + gate_fn (str, Optional): + The activation function for the output gate. Default: `swish`. + elementwise_affine (bool, Optional): + If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + fuse_norm (bool, Optional): + Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. + layer_idx (int, Optional): + The index of the layer. Default: None. 
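+
+ Example (illustrative sketch; assumes a CUDA device since the underlying kernels are Triton-based):
+ >>> import torch
+ >>> from fla.layers.multiscale_retention import MultiScaleRetention
+ >>> layer = MultiScaleRetention(hidden_size=1024, num_heads=8).cuda()
+ >>> x = torch.randn(2, 128, 1024, device='cuda')
+ >>> o, _, _ = layer(x)
+ >>> o.shape
+ torch.Size([2, 128, 1024])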
+ """ + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 1.0, + expand_v: float = 2.0, + num_heads: int = 8, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + use_short_conv: bool = False, + conv_size: int = 4, + conv_bias: bool = False, + use_output_gate: bool = True, + gate_fn: str = 'swish', + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + fuse_norm: bool = True, + layer_idx: int = None, + **kwargs + ) -> MultiScaleRetention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + self.use_output_gate = use_output_gate + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + if self.use_output_gate: + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + if gate_fn == 'swish' and fuse_norm and use_output_gate: + self.g_norm_swish_gate = FusedRMSNormGated( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = RMSNorm( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.gate_fn = ACT2FN[gate_fn] + + # TODO: fix this issue + # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180 + # Ideally, we would want to support arbitrary d_head_qk + assert self.head_k_dim <= 256, "head_k_dim must be less than or equal to 256" + self.rotary = RotaryEmbedding(dim=self.head_k_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 
matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + q = rearrange(q, '... (h d) -> ... h d', d=self.head_k_dim) + k = rearrange(k, '... (h d) -> ... h d', d=self.head_k_dim) + if self.feature_map_fn is not None: + q, k = map(self.feature_map_fn, (q, k)) + + seqlen_offset, max_seqlen = 0, q.shape[1] + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1] + max_seqlen = q.shape[1] + max(seqlen_offset) + + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens) + + if self.num_kv_groups > 1: + k = repeat(k, 'b t h d -> b t (h g) d', g=self.num_kv_groups) + v = repeat(v, 'b t (h d) -> b t (h g) d', d=self.head_v_dim, g=self.num_kv_groups) + else: + v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim) + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + o, recurrent_state = chunk_retention( + q=q, + k=k, + v=v, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + elif mode == 'fused_chunk': + o, recurrent_state = fused_chunk_retention( + q=q, + k=k, + v=v, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + elif mode == 'parallel': + o, recurrent_state = parallel_retention( + q=q, + k=k, + v=v, + cu_seqlens=cu_seqlens, + head_first=False + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_retention( + q=q, + k=k, + v=v, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, 
conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + if self.use_output_gate: + g = self.g_proj(hidden_states) + if self.fuse_norm_and_gate: + g = rearrange(g, 'b t (h d) -> b t h d', d=self.head_v_dim) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, 'b t h d -> b t (h d)') + else: + o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)') + o = o * self.gate_fn(g) + else: + o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)') + o = self.o_proj(o) + + return o, None, past_key_values + + def state_size(self, **kwargs) -> int: + state_size = self.key_dim * self.head_v_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/fla/layers/nsa.py b/fla/layers/nsa.py new file mode 100644 index 0000000000000000000000000000000000000000..0a65b458d6c000bd52c56bd4c93c4357b5c3844f --- /dev/null +++ b/fla/layers/nsa.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple, Union + +import torch +import torch.nn as nn +from einops import rearrange +from transformers.utils import logging + +from fla.modules import RotaryEmbedding +from fla.ops.nsa.parallel import parallel_nsa + +if TYPE_CHECKING: + from fla.models.utils import Cache + +logger = logging.get_logger(__name__) + + +class NativeSparseAttention(nn.Module): + + def __init__( + self, + hidden_size: int = 2048, + num_heads: int = 64, + num_kv_heads: Optional[int] = 4, + head_dim: int = 64, + qkv_bias: bool = False, + block_size: Optional[int] = 64, + block_counts: Optional[Union[torch.LongTensor, int]] = 16, + window_size: Optional[int] = 512, + rope_theta: Optional[float] = 10000., + max_position_embeddings: Optional[int] = None, + layer_idx: int = None + ): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + if num_kv_heads is None: + self.num_kv_heads = self.num_heads + else: + self.num_kv_heads = num_kv_heads + self.num_kv_groups = num_heads // self.num_kv_heads + self.head_dim = head_dim + self.kv_dim = self.num_kv_heads * self.head_dim + self.qkv_bias = qkv_bias + + self.block_size = block_size + self.block_counts = block_counts + self.window_size = window_size + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.qkv_bias) + self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=self.qkv_bias) + self.g_proj = nn.Linear(self.hidden_size, self.num_heads * 3, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). 
" + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." + ) + + batch_size, seq_len, _ = hidden_states.size() + + q = rearrange(self.q_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + k = rearrange(self.k_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(self.v_proj(hidden_states), '... (h d) -> ... h d', d=self.head_dim) + g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', d=3) + g_cmp, g_slc, g_swa = g.sigmoid().unbind(-1) + + cu_seqlens = kwargs.get('cu_seqlens', None) + + seqlen_offset, max_seqlen = 0, seq_len + if past_key_values is not None: + seqlen_offset = past_key_values.get_seq_length(self.layer_idx) + max_seqlen = q.shape[1] + seqlen_offset + + if attention_mask is not None: + # to deliminate the offsets of padding tokens + seqlen_offset = seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1] + max_seqlen = q.shape[1] + max(seqlen_offset) + + if self.max_position_embeddings is not None: + max_seqlen = max(max_seqlen, self.max_position_embeddings) + q, k = self.rotary(q, k, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen, cu_seqlens=cu_seqlens) + + if past_key_values is not None: + cache_has_content = past_key_values.get_seq_length(self.layer_idx) > 0 + k_cached, v_cached = past_key_values.update( + attn_state=(k.flatten(-2, -1), v.flatten(-2, -1)), + layer_idx=self.layer_idx, + offset=seq_len, + cache_kwargs=dict(window_size=self.window_size) + )['attn_state'] + if cache_has_content: + k, v = k_cached, v_cached + k = rearrange(k, '... (h d) -> ... h d', d=self.head_dim) + v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim) + + o = parallel_nsa( + q=q, + k=k, + v=v, + g_cmp=g_cmp, + g_slc=g_slc, + g_swa=g_swa, + block_size=self.block_size, + block_counts=self.block_counts, + window_size=self.window_size, + cu_seqlens=cu_seqlens, + head_first=False + ) + o = o.reshape(batch_size, seq_len, -1) + o = self.o_proj(o) + + if not output_attentions: + attentions = None + + return o, attentions, past_key_values diff --git a/fla/layers/rebased.py b/fla/layers/rebased.py new file mode 100644 index 0000000000000000000000000000000000000000..47c2e17dfe088cc3fda45c3d3aa3605b8fdb1e73 --- /dev/null +++ b/fla/layers/rebased.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +""" +https://github.com/corl-team/rebased/blob/main/flash_linear_attention/fla/layers/rebased_fast.py +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules.feature_map import RebasedFeatureMap +from fla.ops.linear_attn import chunk_linear_attn, fused_chunk_linear_attn +from fla.ops.rebased import parallel_rebased + + +class ReBasedLinearAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + l_max: int = 2048, + feature_dim: int = 16, + num_key_value_heads: int = 16, + num_heads: int = 16, + use_gamma: Optional[bool] = True, + use_beta: Optional[bool] = True, + normalize: Optional[bool] = True, + causal: bool = True, + eps: float = 1e-5, + mode: str = "parallel", + layer_idx: Optional[int] = None, + **kwargs + ) -> ReBasedLinearAttention: + super().__init__() + self.hidden_size = hidden_size + self.l_max = l_max + self.mode = mode + assert self.mode in ["fused_chunk", "parallel", 'chunk'] + + self.feature_dim = feature_dim + self.num_key_value_heads = num_key_value_heads + self.num_heads = num_heads + self.head_dim = self.hidden_size // 
self.num_key_value_heads + self.use_gamma = use_gamma + self.use_beta = use_beta + self.normalize = normalize + self.causal = causal + self.eps = eps + self.mode = mode + self.layer_idx = layer_idx + + self.feature_map = RebasedFeatureMap(self.feature_dim, use_gamma, use_beta, normalize) + self.q_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.feature_dim * self.num_heads, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.dropout = nn.Identity() + + def forward(self, hidden_states: torch.Tensor, **kwargs): + mode = self.mode + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + q, k, v = map(lambda x: rearrange(x, "... (h d) -> ... h d", d=self.head_dim), [q, k, v]) + q, k = self.feature_map(q, flatten=(mode != 'parallel')), self.feature_map(k, flatten=(mode != 'parallel')) + if mode == "fused_chunk": + o = fused_chunk_linear_attn( + q=q, + k=k, + v=v, + normalize=True, + scale=1, + head_first=False + ) + elif mode == 'chunk': + o = chunk_linear_attn( + q=q, + k=k, + v=v, + normalize=True, + scale=1, + head_first=False + ) + elif mode == 'parallel': + assert q.shape[-1] <= 128 + o = parallel_rebased( + q=q, + k=k, + v=v, + eps=self.eps, + use_scale=True, + use_normalize=True, + head_first=False + ) + o = self.o_proj(o) + o = self.dropout(o) + return o + + # https://github.com/HazyResearch/zoology/blob/main/zoology/mixers/based.py#L119 + def forward_reference( + self, + hidden_states: torch.Tensor, + filters: torch.Tensor = None, + *args, + **kwargs + ): + """ + x (torch.Tensor): tensor of shape (b, d, t) + y (torch.Tensor): tensor of shape (b, d, t) + """ + b, t, _ = hidden_states.size() + q, k, v = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states) + + q = q.view(b, t, -1, self.feature_dim).transpose(1, 2) + k = k.view(b, t, -1, self.feature_dim).transpose(1, 2) + v = v.view(b, t, -1, self.head_dim).transpose(1, 2) + + # Linear attention + q, k = self.feature_map(q), self.feature_map(k) + q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1) + + # Compute attention + if self.causal: + y = ((q * (k * v).cumsum(2)).sum(-1) / ((q * k.cumsum(2)).sum(-1) + self.eps)) + else: + y = ((q * (k * v).sum(2, True)).sum(-1) / ((q * k.sum(2, True)).sum(-1) + self.eps)) + y = rearrange(y, 'b h t d -> b t (h d)') + y = self.o_proj(y.to(hidden_states.dtype)) + y = self.dropout(y) + return y.to(hidden_states.dtype) diff --git a/fla/layers/rwkv6.py b/fla/layers/rwkv6.py new file mode 100644 index 0000000000000000000000000000000000000000..6d8f01660adb3c164ac4ffb5ce7cf26ecb6e85a2 --- /dev/null +++ b/fla/layers/rwkv6.py @@ -0,0 +1,307 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# "Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence"[https://arxiv.org/abs/2404.05892] + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from einops import rearrange + +from fla.modules import GroupNorm +from fla.modules.activations import ACT2FN +from fla.ops.rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +class RWKV6Attention(nn.Module): + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + 
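# expand_k / expand_v below scale the key/value projection widths relative to
+ # hidden_size: key_dim = int(hidden_size * expand_k) and value_dim = int(hidden_size * expand_v),
+ # i.e. 512 and 1024 with the defaults shown here.
+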
expand_k: float = 0.5, + expand_v: float = 1.0, + num_heads: int = 4, + gate_fn: str = 'swish', + proj_low_rank_dim: int = 32, + gate_low_rank_dim: int = 64, + fuse_norm: bool = True, + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + layer_idx: int = None, + **kwargs + ) -> RWKV6Attention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.proj_low_rank_dim = proj_low_rank_dim + self.gate_low_rank_dim = gate_low_rank_dim + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.layer_idx = layer_idx + + assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.x_proj = nn.Sequential( + LerpLinear(hidden_size, proj_low_rank_dim * 5), + nn.Tanh(), + nn.Linear(proj_low_rank_dim * 5, hidden_size, bias=False) + ) + self.x_bias = nn.Parameter(torch.zeros(5, hidden_size)) + + self.r_proj = DDLerpLinear(hidden_size, self.key_dim) + self.w_proj = DDLerpLinear(hidden_size, self.key_dim, low_rank_dim=gate_low_rank_dim) + self.k_proj = DDLerpLinear(hidden_size, self.key_dim) + self.v_proj = DDLerpLinear(hidden_size, self.value_dim) + self.g_proj = DDLerpLinear(hidden_size, self.value_dim) + self.bonus = nn.Parameter(torch.zeros(num_heads, self.head_k_dim)) + + # TODO: fuse GroupNorm and output gate + self.g_norm = GroupNorm(self.num_heads, self.value_dim, elementwise_affine=elementwise_affine, bias=True, eps=norm_eps) + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + self.gate_fn = ACT2FN[gate_fn] + + self.apply(self._initialize_weights) + + def _initialize_weights(self, module: nn.Module): + if getattr(module, "_is_hf_initialized", False): + return + if isinstance(module, nn.Linear): + nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) + if module.bias is not None: + nn.init.zeros_(module.bias) + if isinstance(module, nn.Parameter): + nn.init.xavier_uniform_(module, gain=2 ** -2.5) + module._is_hf_initialized = True + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." 
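+ # The 2-D mask is only used to zero out left-padded positions in the inputs and
+ # values further below; it is never materialized as an attention bias.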
+ ) + + batch_size, seq_len, hidden_size = hidden_states.shape + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + if attention_mask is not None: + hidden_states = hidden_states.mul_(attention_mask[:, -hidden_states.shape[-2]:, None]) + if hidden_states.shape[1] == 1 and last_state is not None: + shifted = last_state['conv_state'].unsqueeze(1) + else: + shifted = self.time_shift(hidden_states) + if last_state is not None: + shifted[:, 0] = last_state['conv_state'] + + delta = shifted - hidden_states + x = self.x_proj[0](hidden_states, delta).view(batch_size, seq_len, -1, self.proj_low_rank_dim) + x = torch.einsum('b t n r, h n r-> b t n h', self.x_proj[1](x), self.x_proj[2].weight.view(hidden_size, 5, -1)) + + r, w, k, v, g = x.add_(self.x_bias).unbind(-2) + r = self.r_proj(hidden_states, r, delta) + w = self.w_proj(hidden_states, w, delta) + k = self.k_proj(hidden_states, k, delta) + v = self.v_proj(hidden_states, v, delta) + g = self.g_proj(hidden_states, g, delta) + + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + r, w, k = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', d=self.head_k_dim), (r, w, k)) + v = rearrange(v, 'b t (h d) -> b t h d', d=self.head_v_dim) + w = -torch.exp(w) + u = self.bonus + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + cu_seqlens = kwargs.get('cu_seqlens', None) + if mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_rwkv6( + r=r, + k=k, + v=v, + w=w, + u=u, + scale=1., + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + elif mode == 'chunk': + o, recurrent_state = chunk_rwkv6( + q=r, + k=k, + v=v, + g=w, + u=u, + scale=1., + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=hidden_states[:, -1], + layer_idx=self.layer_idx, + offset=r.shape[2] + ) + + o = self.g_norm(rearrange(o, '... h d -> ... 
(h d)')) * self.gate_fn(g) + o = self.o_proj(o) + + return o, None, past_key_values + + +class LoRA(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: int, + bias: Optional[bool] = True, + activation: Optional[str] = 'tanh' + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + self.bias = bias + + if activation is None: + self.activation = nn.Identity() + elif activation == 'sigmoid': + self.activation = nn.Sigmoid() + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'relu': + self.activation = nn.ReLU() + else: + raise ValueError(f"Not supported activation `{activation}`.") + + self.lora = nn.Sequential( + nn.Linear(input_dim, low_rank_dim, bias=False), + self.activation, + nn.Linear(low_rank_dim, output_dim, bias=bias) + ) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(" + s += f"input_dim={self.input_dim}, low_rank_dim={self.low_rank_dim}, output_dim={self.output_dim}" + if not self.bias: + s += f", bias={self.bias}" + s += ")" + return s + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lora(x) + + +class LerpLinear(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: Optional[int] = None + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + if low_rank_dim is None: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = LoRA(input_dim, output_dim, low_rank_dim) + self.mu = nn.Parameter(torch.zeros(input_dim)) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}" + if self.low_rank_dim is not None: + s += f", low_rank_dim={self.low_rank_dim}" + s += ")" + return s + + def forward(self, x: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor: + if delta is None: + shifted = self.time_shift(x) + if len(shifted.shape) == 2: + shifted = shifted.unsqueeze(1) + delta = shifted - x + return self.linear(x + delta * self.mu) + + +class DDLerpLinear(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + low_rank_dim: Optional[int] = None + ): + super().__init__() + + self.input_dim = input_dim + self.output_dim = output_dim + self.low_rank_dim = low_rank_dim + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + if low_rank_dim is None: + self.linear = nn.Linear(input_dim, output_dim, bias=False) + else: + self.linear = LoRA(input_dim, output_dim, low_rank_dim) + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}({self.input_dim}, {self.output_dim}" + if self.low_rank_dim is not None: + s += f", low_rank_dim={self.low_rank_dim}" + s += ")" + return s + + def forward(self, x: torch.Tensor, mu: torch.Tensor, delta: Optional[torch.Tensor] = None) -> torch.Tensor: + if delta is None: + shifted = self.time_shift(x) + if len(shifted.shape) == 2: + shifted = shifted.unsqueeze(1) + delta = shifted - x + return self.linear(x + delta * mu) diff --git a/fla/layers/simple_gla.py b/fla/layers/simple_gla.py new file mode 100644 index 0000000000000000000000000000000000000000..0771a90ce2e15b1a43f4ffa081c63a37ce3ea124 --- /dev/null +++ b/fla/layers/simple_gla.py @@ -0,0 +1,261 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Tuple + +import torch 
+import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution +from fla.modules.activations import ACT2FN +from fla.ops.simple_gla import chunk_simple_gla, fused_recurrent_simple_gla + +if TYPE_CHECKING: + from fla.models.utils import Cache + + +class SimpleGatedLinearAttention(nn.Module): + r""" + The layer implementaion for [Gated Linear Attention Transformers with Hardware-Efficient Training](https://arxiv.org/abs/2312.06635). # noqa + This layer calls the simplified GLA kernel in which the gating is head-wise instead of elementwise. + + Args: + mode (str, Optional): + Which GLA kernel to use. + Currently available: `chunk`. + Default: `chunk`. + hidden_size (int, Optional): + The hidden size of the input. Default: 1024. + expand_k (float, Optional): + The expansion ratio for the key dim. Default: 1.0. + expand_v (float, Optional): + The expansion ratio for the value dim. Default: 1.0. + num_heads (int, Optional): + The number of heads. Default: 4. + num_kv_heads (int, Optional): + The number of key/value heads, used for MQA. Default: None. + feature_map (str, Optional): + Feature map function applied to queries/keys. Default: None. + use_short_conv (bool, Optional): + Whether to use short convolutions. Default: `False`. + conv_size (int, Optional): + The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. + conv_bias (bool, Optional): + Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. + gate_fn (str, Optional): + The activation function for the output gate. Default: `swish`. + elementwise_affine (bool, Optional): + If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. + norm_eps (float, Optional): + The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. + gate_logit_normalizer (int, Optional): + The normalizer for the gate logits, appied after `logsigmoid`. Default: 16. + fuse_norm (bool, Optional): + Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. + layer_idx (int, Optional): + The index of the layer. Default: None. 
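+
+ Example (illustrative sketch; assumes a CUDA device since the kernels are Triton-based):
+ >>> layer = SimpleGatedLinearAttention(hidden_size=1024, num_heads=4).cuda()
+ >>> x = torch.randn(2, 128, 1024, device='cuda')
+ >>> o, _, cache = layer(x) # o: [2, 128, 1024]; cache is None when no past_key_values is given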
+ """ + + def __init__( + self, + mode: str = 'chunk', + hidden_size: int = 1024, + expand_k: float = 1., + expand_v: float = 1., + num_heads: int = 4, + num_kv_heads: Optional[int] = None, + feature_map: Optional[str] = None, + use_short_conv: bool = True, + conv_size: int = 4, + conv_bias: bool = False, + gate_fn: str = 'swish', + elementwise_affine: Optional[bool] = True, + norm_eps: float = 1e-5, + gate_logit_normalizer: int = 16, + fuse_norm: bool = True, + layer_idx: int = None, + ) -> SimpleGatedLinearAttention: + super().__init__() + + self.mode = mode + self.hidden_size = hidden_size + self.expand_k = expand_k + self.expand_v = expand_v + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None + + self.use_short_conv = use_short_conv + self.conv_size = conv_size + self.conv_bias = conv_bias + + self.key_dim = int(hidden_size * expand_k) + self.value_dim = int(hidden_size * expand_v) + self.key_dim_per_group = self.key_dim // self.num_kv_groups + self.value_dim_per_group = self.value_dim // self.num_kv_groups + self.layer_idx = layer_idx + + assert mode in ['chunk', "fused_recurrent"], f"Not suppoerted mode `{mode}`." + assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" + assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" + + self.head_k_dim = self.key_dim // num_heads + self.head_v_dim = self.value_dim // num_heads + + self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) + self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) + self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) + + if use_short_conv: + self.conv_size = conv_size + self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') + self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') + self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') + + self.gk_proj = nn.Linear(hidden_size, self.num_heads) + + if gate_fn == 'swish' and fuse_norm: + self.g_norm_swish_gate = FusedRMSNormGated( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.fuse_norm_and_gate = True + else: + self.fuse_norm_and_gate = False + self.g_norm = RMSNorm( + hidden_size=self.head_v_dim, + elementwise_affine=elementwise_affine, + eps=norm_eps + ) + self.gate_fn = ACT2FN[gate_fn] + self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) + + self.gate_logit_normalizer = gate_logit_normalizer + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: + if attention_mask is not None: + assert len(attention_mask.shape) == 2, ( + "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " + "for padding purposes (0 indicating padding). " + "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." 
+ ) + + # launching the triton kernel for just one token will actually be slower + mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode + + last_state = None + if past_key_values is not None and len(past_key_values) > self.layer_idx: + last_state = past_key_values[self.layer_idx] + + cu_seqlens = kwargs.get('cu_seqlens', None) + if self.use_short_conv: + conv_state_q, conv_state_k, conv_state_v = None, None, None + if last_state is not None: + conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] + conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None + q, conv_state_q = self.q_conv1d( + x=self.q_proj(hidden_states), + mask=conv_mask, + cache=conv_state_q, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + k, conv_state_k = self.k_conv1d( + x=self.k_proj(hidden_states), + mask=conv_mask, + cache=conv_state_k, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + v, conv_state_v = self.v_conv1d( + x=self.v_proj(hidden_states), + mask=conv_mask, + cache=conv_state_v, + output_final_state=use_cache, + cu_seqlens=cu_seqlens + ) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + gk = self.gk_proj(hidden_states) + + if self.feature_map_fn is not None: + q, k = map(self.feature_map_fn, (q, k)) + # dealing with left-padding + if attention_mask is not None: + v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) + q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads) + if self.num_kv_groups > 1: + k, v = (repeat(x, '... (h d) -> ... (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) for x in (k, v)) + else: + k, v = (rearrange(x, '... (h d) -> ... h d', h=self.num_kv_heads) for x in (k, v)) + gk = F.logsigmoid(gk) / self.gate_logit_normalizer + + recurrent_state = last_state['recurrent_state'] if last_state is not None else None + if mode == 'chunk': + o, recurrent_state = chunk_simple_gla( + q=q, + k=k, + v=v, + gk=gk, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + elif mode == 'fused_recurrent': + o, recurrent_state = fused_recurrent_simple_gla( + q=q, + k=k, + v=v, + gk=gk, + initial_state=recurrent_state, + output_final_state=use_cache, + cu_seqlens=cu_seqlens, + head_first=False + ) + else: + raise NotImplementedError(f"Not supported mode `{mode}`.") + + if past_key_values is not None: + past_key_values.update( + recurrent_state=recurrent_state, + conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, + layer_idx=self.layer_idx, + offset=q.shape[1] + ) + + g = self.g_proj(hidden_states) + if self.fuse_norm_and_gate: + g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads) + o = self.g_norm_swish_gate(o, g) + o = rearrange(o, 'b t h d -> b t (h d)') + else: + o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)') + o = o * self.gate_fn(g) + o = self.o_proj(o) + + return o, None, past_key_values + + def state_size(self, **kwargs) -> int: + state_size = self.key_dim * self.head_v_dim + for module in self.children(): + if isinstance(module, ShortConvolution): + state_size += module.state_size + return state_size diff --git a/fla/ops/__init__.py b/fla/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a66b57860c95675ef46823e79431ae8868a1229e --- /dev/null +++ b/fla/ops/__init__.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- + +from .abc import chunk_abc +from .attn import parallel_attn, parallel_rectified_attn, 
parallel_softpick_attn, naive_attn, naive_rectified_attn, naive_softpick_attn +from .based import fused_chunk_based, parallel_based +from .delta_rule import chunk_delta_rule, fused_chunk_delta_rule, fused_recurrent_delta_rule +from .forgetting_attn import parallel_forgetting_attn +from .gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule +from .generalized_delta_rule import ( + chunk_dplr_delta_rule, + chunk_iplr_delta_rule, + fused_recurrent_dplr_delta_rule, + fused_recurrent_iplr_delta_rule +) +from .gla import chunk_gla, fused_chunk_gla, fused_recurrent_gla +from .gsa import chunk_gsa, fused_recurrent_gsa +from .hgrn import fused_recurrent_hgrn +from .lightning_attn import chunk_lightning_attn, fused_recurrent_lightning_attn +from .linear_attn import chunk_linear_attn, fused_chunk_linear_attn, fused_recurrent_linear_attn +from .nsa import parallel_nsa +from .retention import chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention +from .rwkv6 import chunk_rwkv6, fused_recurrent_rwkv6 +from .rwkv7 import chunk_rwkv7, fused_recurrent_rwkv7 +from .simple_gla import chunk_simple_gla, fused_recurrent_simple_gla, parallel_simple_gla + +__all__ = [ + 'chunk_abc', + 'parallel_attn', 'parallel_rectified_attn', 'parallel_softpick_attn', + 'naive_attn', 'naive_rectified_attn', 'naive_softpick_attn', + 'fused_chunk_based', 'parallel_based', + 'chunk_delta_rule', 'fused_chunk_delta_rule', 'fused_recurrent_delta_rule', + 'parallel_forgetting_attn', + 'chunk_gated_delta_rule', 'fused_recurrent_gated_delta_rule', + 'chunk_dplr_delta_rule', 'chunk_iplr_delta_rule', + 'fused_recurrent_dplr_delta_rule', 'fused_recurrent_iplr_delta_rule', + 'chunk_gla', 'fused_chunk_gla', 'fused_recurrent_gla', + 'chunk_gsa', 'fused_recurrent_gsa', + 'fused_recurrent_hgrn', + 'chunk_lightning_attn', 'fused_recurrent_lightning_attn', + 'chunk_linear_attn', 'fused_chunk_linear_attn', 'fused_recurrent_linear_attn', + 'parallel_nsa', + 'chunk_retention', 'fused_chunk_retention', 'fused_recurrent_retention', 'parallel_retention', + 'chunk_rwkv6', 'fused_recurrent_rwkv6', + 'chunk_rwkv7', 'fused_recurrent_rwkv7', + 'chunk_simple_gla', 'fused_recurrent_simple_gla', 'parallel_simple_gla', +] diff --git a/fla/ops/attn/__init__.py b/fla/ops/attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..083462677513fb00d5c658218b41188d0216ba3b --- /dev/null +++ b/fla/ops/attn/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_attn +from .parallel_rectified import parallel_rectified_attn +from .parallel_softpick import parallel_softpick_attn +from .naive import naive_attn +from .naive_rectified import naive_rectified_attn +from .naive_softpick import naive_softpick_attn + +__all__ = [ + 'parallel_attn', + 'parallel_rectified_attn', + 'parallel_softpick_attn', + 'naive_attn', + 'naive_rectified_attn', + 'naive_softpick_attn', +] diff --git a/fla/ops/attn/__pycache__/__init__.cpython-311.pyc b/fla/ops/attn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8e54074abbf6fb55a325c06d497f76552092abf Binary files /dev/null and b/fla/ops/attn/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/attn/__pycache__/naive.cpython-311.pyc b/fla/ops/attn/__pycache__/naive.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d72774176835c20d8d425b219bb78fb79992433 Binary files /dev/null and 
b/fla/ops/attn/__pycache__/naive.cpython-311.pyc differ diff --git a/fla/ops/attn/__pycache__/naive_rectified.cpython-311.pyc b/fla/ops/attn/__pycache__/naive_rectified.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4b13c57272d2fb8f527137298b2a8cfd0e71466 Binary files /dev/null and b/fla/ops/attn/__pycache__/naive_rectified.cpython-311.pyc differ diff --git a/fla/ops/attn/__pycache__/parallel.cpython-311.pyc b/fla/ops/attn/__pycache__/parallel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1de56afb71a872be0babaa7775c6053fcde5aaca Binary files /dev/null and b/fla/ops/attn/__pycache__/parallel.cpython-311.pyc differ diff --git a/fla/ops/attn/naive.py b/fla/ops/attn/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..89e9acc24a8e5854df7097e7ac80ae4815939caa --- /dev/null +++ b/fla/ops/attn/naive.py @@ -0,0 +1,28 @@ +import torch +from typing import Optional +from einops import rearrange + +def naive_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + head_dim = q.shape[-1] + if scale is None: + scale = 1.0 / (head_dim ** 0.5) + if not head_first: + q, k, v = map(lambda x: rearrange(x, 'b t h d -> b h t d'), (q, k, v)) + q_len = q.shape[-2] + k_len = k.shape[-2] + mask = torch.tril(torch.ones(k_len, k_len, device=q.device)) + wei = torch.matmul(q, k.transpose(2, 3)) # shape: (batch_size, num_heads, q_len, k_len) + wei = wei * scale + wei = wei.masked_fill(mask[k_len-q_len:k_len, :k_len] == 0, float('-inf')) + wei = torch.softmax(wei.float(), dim=-1).to(q.dtype) + o = torch.matmul(wei, v) # shape: (batch_size, num_heads, q_len, head_dim) + if not head_first: + o = rearrange(o, 'b h t d -> b t h d') + return o, wei \ No newline at end of file diff --git a/fla/ops/attn/naive_rectified.py b/fla/ops/attn/naive_rectified.py new file mode 100644 index 0000000000000000000000000000000000000000..1438115da09b6726f31cd7dcf356cab2a457a433 --- /dev/null +++ b/fla/ops/attn/naive_rectified.py @@ -0,0 +1,30 @@ +import torch +from typing import Optional +from einops import rearrange + +def naive_rectified_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + head_dim = q.shape[-1] + if scale is None: + scale = 1.0 / (head_dim ** 0.5) + if not head_first: + q, k, v = map(lambda x: rearrange(x, 'b t h d -> b h t d'), (q, k, v)) + q_len = q.shape[-2] + k_len = k.shape[-2] + mask = torch.tril(torch.ones(k_len, k_len, device=q.device)) + wei = torch.matmul(q, k.transpose(2, 3)) # shape: (batch_size, num_heads, q_len, k_len) + wei = wei * scale + wei = torch.where(wei >= 0, wei, float('-inf')) + wei = wei.masked_fill(mask[k_len-q_len:k_len, :k_len] == 0, float('-inf')) + wei = torch.softmax(wei.float(), dim=-1).to(q.dtype) + wei = torch.nan_to_num(wei, nan=0.0) + o = torch.matmul(wei, v) # shape: (batch_size, num_heads, q_len, head_dim) + if not head_first: + o = rearrange(o, 'b h t d -> b t h d') + return o, wei \ No newline at end of file diff --git a/fla/ops/attn/naive_softpick.py b/fla/ops/attn/naive_softpick.py new file mode 100644 index 0000000000000000000000000000000000000000..f794ef2a968d52fc891e1477ba141716c1a3c6aa --- /dev/null +++ b/fla/ops/attn/naive_softpick.py @@ -0,0 +1,39 @@ +import torch +import torch.nn.functional as F 
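+
+# Reference (pure PyTorch) softpick attention: the row-wise softmax is replaced by
+# softpick(x) = relu(exp(x) - 1) / (sum(|exp(x) - 1|) + eps). The stable form below
+# computes exp(x) - 1 as exp(x - m) - exp(-m) with m = max(x), so the common exp(m)
+# factor cancels in the ratio, and masked (-inf) scores are excluded from the
+# denominator via the isfinite check.
+#
+# Illustrative usage (assumed shapes, head_first=False):
+#   q = k = v = torch.randn(1, 16, 2, 8)     # [batch, seq, heads, head_dim]
+#   o, wei = naive_softpick_attn(q, k, v)    # o: [1, 16, 2, 8], wei: [1, 2, 16, 16]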
+from typing import Optional +from einops import rearrange + +def softpick(x, dim=-1, eps=1e-8): + # softpick function: relu(exp(x)-1) / sum(abs(exp(x)-1)) + # numerically stable version + x_m = torch.max(x, dim=dim, keepdim=True).values + x_m_e_m = torch.exp(-x_m) + x_e_1 = torch.exp(x - x_m) - x_m_e_m + r_x_e_1 = F.relu(x_e_1) + a_x_e_1 = torch.where(x.isfinite(), torch.abs(x_e_1), 0) + return r_x_e_1 / (torch.sum(a_x_e_1, dim=dim, keepdim=True) + eps) # epsilon is only useful if all inputs are EXACTLY 0. we might not even need it + +def naive_softpick_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + head_dim = q.shape[-1] + if scale is None: + scale = 1.0 / (head_dim ** 0.5) + if not head_first: + q, k, v = map(lambda x: rearrange(x, 'b t h d -> b h t d'), (q, k, v)) + q_len = q.shape[-2] + k_len = k.shape[-2] + mask = torch.tril(torch.ones(k_len, k_len, device=q.device)) + wei = torch.matmul(q, k.transpose(2, 3)) # shape: (batch_size, num_heads, q_len, k_len) + wei = wei * scale + wei = wei.masked_fill(mask[k_len-q_len:k_len, :k_len] == 0, float('-inf')) + wei = softpick(wei.float(), dim=-1).to(q.dtype) + o = torch.matmul(wei, v) # shape: (batch_size, num_heads, q_len, head_dim) + if not head_first: + o = rearrange(o, 'b h t d -> b t h d') + return o, wei \ No newline at end of file diff --git a/fla/ops/attn/parallel.py b/fla/ops/attn/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..d19a2e1b13398bf81cb503564c8652ff5735eee3 --- /dev/null +++ b/fla/ops/attn/parallel.py @@ -0,0 +1,629 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl +from einops import rearrange, reduce + +from fla.ops.common.utils import prepare_chunk_indices +from fla.ops.utils.op import exp, log +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit +def parallel_attn_fwd_kernel( + q, + k, + v, + o, + lse, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # the Q block is kept in the shared memory throughout the whole kernel 
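+ # Each program computes one [BT, BV] output tile for a single (batch, query-head) pair.
+ # K/V are streamed in [BS]-sized blocks with an online softmax: b_m tracks the running
+ # row maximum and b_acc the running normalizer, and previously accumulated outputs are
+ # rescaled by exp(old_max - new_max) whenever the maximum grows. The first loop covers
+ # key blocks strictly below the diagonal (no masking needed); the second loop handles
+ # the diagonal blocks with the causal mask. The final per-query log-sum-exp is written
+ # to `lse` for the backward pass.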
+ # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + + b_m = tl.full([BT], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([BT], dtype=tl.float32) + for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + + # [BT, BS] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = exp(b_s - b_m[:, None]) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')) + + # [BT] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = exp(b_s - b_m[:, None]) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + b_o = b_o / b_acc[:, None] + b_m += log(b_acc) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,)) + + +@triton.jit +def parallel_attn_bwd_kernel_preprocess( + o, + do, + delta, + B: tl.constexpr, + V: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_attn_bwd_kernel_dq( + q, + k, + v, + lse, + delta, + do, + dq, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), 
(BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BT] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + + # [BT, BK] + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_p = exp(b_s - b_lse[:, None]) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_p = exp(b_s - b_lse[:, None]) + b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + + b_dq *= scale + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_attn_bwd_kernel_dkv( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + offsets, + indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * 
V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + o_k = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_p = exp(b_s - b_lse[None, :]) + b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_p = exp(b_s - b_lse[None, :]) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def parallel_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + if check_shared_mem('hopper', q.device.index): + BS = min(64, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(256, max(16, triton.next_power_of_2(V))) + elif check_shared_mem('ampere', q.device.index): + BS = min(32, max(16, 
triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(128, max(16, triton.next_power_of_2(V))) + else: + BS = min(32, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(64, max(16, triton.next_power_of_2(V))) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + assert NK == 1, "The key dimension can not be larger than 256" + + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + grid = (NV, NT, B * HQ) + parallel_attn_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + offsets=offsets, + indices=indices, + B=B, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_attn_bwd_preprocess( + o: torch.Tensor, + do: torch.Tensor +): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float32) + parallel_attn_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_attn_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + scale: float = None, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + BS = max(16, triton.next_power_of_2(T)) + BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS) + BK = max(16, triton.next_power_of_2(K)) + BV = max(16, triton.next_power_of_2(V)) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + delta = parallel_attn_bwd_preprocess(o, do) + + dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device) + dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device) + dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device) + grid = (NV, NT, B * HQ) + parallel_attn_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + parallel_attn_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum') + dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum') + return dq, dk, dv + + +@torch.compile +class ParallelAttentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, offsets): + ctx.dtype = q.dtype + + chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1]))) + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + + o, lse = parallel_attn_fwd( + q=q, + k=k, + v=v, + 
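# `offsets` is the cu_seqlens tensor forwarded for variable-length inputs; the returned
+ # `lse` (per-query log-sum-exp) is saved so the backward kernels can recompute the
+ # attention probabilities without ever materializing them.
+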
scale=scale, + chunk_size=chunk_size, + offsets=offsets, + indices=indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.chunk_size = chunk_size + ctx.offsets = offsets + ctx.indices = indices + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_attn_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + scale=ctx.scale, + chunk_size=ctx.chunk_size, + offsets=ctx.offsets, + indices=ctx.indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA will be applied if HQ is divisible by H. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + o = ParallelAttentionFunction.apply(q, k, v, scale, cu_seqlens) + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o diff --git a/fla/ops/attn/parallel_rectified.py b/fla/ops/attn/parallel_rectified.py new file mode 100644 index 0000000000000000000000000000000000000000..4025c99fcf1c63c63083c52fb5ed8536bd82fde3 --- /dev/null +++ b/fla/ops/attn/parallel_rectified.py @@ -0,0 +1,643 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl +from einops import rearrange, reduce + +from fla.ops.common.utils import prepare_chunk_indices +from fla.ops.utils.op import exp, log +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit +def parallel_rect_attn_fwd_kernel( + q, + k, + v, + o, + lse, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if 
USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + + b_m = tl.full([BT], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([BT], dtype=tl.float32) + for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + + # [BT, BS] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + # b_p = exp(b_s - b_m[:, None]) + # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_m[:, None])) + b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_m[:, None])) # Just do this + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')) + + # [BT] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + # b_p = exp(b_s - b_m[:, None]) + # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_m[:, None])) + b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_m[:, None])) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + # b_o = b_o / b_acc[:, None] + b_o = tl.where(b_acc[:, None] == 0, 0.0, b_o / b_acc[:, None]) + # b_m += tl.log(b_acc) + b_m = tl.where(b_acc == 0, 0.0, b_m + tl.log(b_acc)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,)) + + +@triton.jit +def parallel_rect_attn_bwd_kernel_preprocess( + o, + do, + delta, + B: tl.constexpr, + V: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, 
num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_rect_attn_bwd_kernel_dq( + q, + k, + v, + lse, + delta, + do, + dq, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BT] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + + # [BT, BK] + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BS] + b_s = tl.dot(b_q, b_k) + # b_p = exp(b_s - b_lse[:, None]) + # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_lse[:, None])) + b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_lse[:, None])) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BS] + b_s = tl.dot(b_q, b_k) + # b_p = exp(b_s - b_lse[:, None]) + # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_lse[:, None])) + b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_lse[:, None])) + b_p = tl.where(o_q[:, None] >= o_k[None, :], b_p, 0) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + + b_dq *= scale + + tl.store(p_dq, 
b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_rect_attn_bwd_kernel_dkv( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + offsets, + indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + o_k = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) + # b_p = exp(b_s - b_lse[None, :]) + # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_lse[None, :])) + b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_lse[None, :])) + b_p = tl.where(o_k[:, None] <= o_q[None, :], b_p, 0) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = 
tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) + # b_p = exp(b_s - b_lse[None, :]) + # b_p = exp(tl.where(b_s < 0, float('-inf'), b_s - b_lse[None, :])) + b_p = tl.where(b_s < 0, 0.0, exp(b_s - b_lse[None, :])) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def parallel_rect_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + if check_shared_mem('hopper', q.device.index): + BS = min(64, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(256, max(16, triton.next_power_of_2(V))) + elif check_shared_mem('ampere', q.device.index): + BS = min(32, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(128, max(16, triton.next_power_of_2(V))) + else: + BS = min(32, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(64, max(16, triton.next_power_of_2(V))) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + assert NK == 1, "The key dimension can not be larger than 256" + + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + grid = (NV, NT, B * HQ) + parallel_rect_attn_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + offsets=offsets, + indices=indices, + B=B, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_rect_attn_bwd_preprocess( + o: torch.Tensor, + do: torch.Tensor +): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float32) + parallel_rect_attn_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_rect_attn_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + scale: float = None, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + BS = max(16, triton.next_power_of_2(T)) + BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS) + BK = max(16, triton.next_power_of_2(K)) + BV = max(16, triton.next_power_of_2(V)) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + delta = parallel_rect_attn_bwd_preprocess(o, do) + + dq = torch.empty(B, T, HQ, K, 
dtype=k.dtype if H == HQ else torch.float, device=q.device) + dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device) + dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device) + grid = (NV, NT, B * HQ) + parallel_rect_attn_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + parallel_rect_attn_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum') + dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum') + return dq, dk, dv + + +@torch.compile +class ParallelRectifiedAttentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, offsets): + ctx.dtype = q.dtype + + chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1]))) + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + + o, lse = parallel_rect_attn_fwd( + q=q, + k=k, + v=v, + scale=scale, + chunk_size=chunk_size, + offsets=offsets, + indices=indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.chunk_size = chunk_size + ctx.offsets = offsets + ctx.indices = indices + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_rect_attn_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + scale=ctx.scale, + chunk_size=ctx.chunk_size, + offsets=ctx.offsets, + indices=ctx.indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_rectified_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA will be applied if HQ is divisible by H. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. 
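        A minimal usage sketch (illustrative only; the shapes, dtypes and the GQA
        ratio `HQ // H = 4` below are arbitrary choices, not requirements):

        Examples::

            >>> import torch
            >>> B, T, HQ, H, K, V = 2, 1024, 8, 2, 64, 64
            >>> q = torch.randn(B, T, HQ, K, dtype=torch.bfloat16, device='cuda')
            >>> k = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
            >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
            >>> o = parallel_rectified_attn(q, k, v)
            >>> o.shape
            torch.Size([2, 1024, 8, 64])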
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + o = ParallelRectifiedAttentionFunction.apply(q, k, v, scale, cu_seqlens) + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o diff --git a/fla/ops/attn/parallel_softpick.py b/fla/ops/attn/parallel_softpick.py new file mode 100644 index 0000000000000000000000000000000000000000..ec23b387050d42794bc705e2c4b8c69125323121 --- /dev/null +++ b/fla/ops/attn/parallel_softpick.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl +from einops import rearrange, reduce + +from fla.ops.common.utils import prepare_chunk_indices +from fla.ops.utils.op import exp, log +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit +def parallel_softpick_attn_fwd_kernel( + q, + k, + v, + o, + lse, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_o = tl.zeros([BT, BV], dtype=tl.float32) + + b_m = tl.full([BT], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([BT], dtype=tl.float32) + for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + + # [BT, BS] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = exp(b_s - b_m[:, None]) - exp(-b_m[:, None]) + b_p_r = tl.maximum(b_p, 0.0) + b_p_a = tl.abs(b_p) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p_a, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p_r.to(b_q.dtype), b_v) + + b_mp = b_m + + # [BT] + o_q = i_t * BT + 
tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')) + + # [BT] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = exp(b_s - b_m[:, None]) - exp(-b_m[:, None]) + b_p_r = tl.maximum(b_p, 0.0) + b_p_a = tl.abs(b_p) + b_p_a = tl.where(o_q[:, None] >= o_k[None, :], b_p_a, 0) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p_a, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p_r.to(b_q.dtype), b_v) + + b_mp = b_m + b_acc += 1e-6 # harcoded epsilon... sorry + b_o = b_o / b_acc[:, None] + b_m += log(b_acc) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,)) + + +@triton.jit +def parallel_softpick_attn_bwd_kernel_preprocess( + o, + do, + delta, + B: tl.constexpr, + V: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_softpick_attn_bwd_kernel_dq( + q, + k, + v, + lse, + delta, + do, + dq, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BT] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + + # [BT, BK] + b_dq = tl.zeros([BT, BK], dtype=tl.float32) 
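    # A sketch of the gradient accumulated below: for softpick attention,
    # p = relu(exp(s) - 1) / sum_j |exp(s_j) - 1|, and with e = exp(s - lse)
    # (the stored lse already folds in the normalizer) the chain rule gives
    #     ds[t, j] = e[t, j] * (1[s[t, j] > 0] * dp[t, j] - sign(s[t, j]) * delta[t]),
    # where dp[t, j] = dot(do_t, v_j) and delta[t] = dot(o_t, do_t) comes from the
    # preprocess kernel. b_step and b_sign in the two loops below are exactly these
    # two terms, and dq accumulates ds @ k^T over all visible key blocks.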
+ for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_e = exp(b_s - b_lse[:, None]) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + # [BT, BS] + b_step = tl.where(b_s > 0, b_dp, 0) + b_sign = tl.where(b_s > 0, b_delta[:, None], -b_delta[:, None]) + b_ds = b_e * (b_step.to(tl.float32) - b_sign) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_e = exp(b_s - b_lse[:, None]) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + # [BT, BS] + b_e = tl.where(o_q[:, None] >= o_k[None, :], b_e, 0) + b_step = tl.where(b_s > 0, b_dp, 0) + b_sign = tl.where(b_s > 0, b_delta[:, None], -b_delta[:, None]) + b_ds = b_e * (b_step.to(tl.float32) - b_sign) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + + b_dq *= scale + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_softpick_attn_bwd_kernel_dkv( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + offsets, + indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + o_k = i_t * BT + 
tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_e = exp(b_s - b_lse[None, :]) + b_p = b_e - exp(-b_lse[None, :]) + b_p_r = tl.maximum(b_p, 0.0) + b_p_r = tl.where(o_k[:, None] <= o_q[None, :], b_p_r, 0) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p_r.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_e = tl.where(o_k[:, None] <= o_q[None, :], b_e, 0) + b_step = tl.where(b_s > 0, b_dp, 0) + b_sign = tl.where(b_s > 0, b_delta[None, :], -b_delta[None, :]) + b_ds = b_e * (b_step - b_sign) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + for i_s in range((i_t + 1) * BT, tl.cdiv(T, BS) * BS, BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_e = exp(b_s - b_lse[None, :]) + b_p = b_e - exp(-b_lse[None, :]) + b_p_r = tl.maximum(b_p, 0.0) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p_r.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_step = tl.where(b_s > 0, b_dp, 0) + b_sign = tl.where(b_s > 0, b_delta[None, :], -b_delta[None, :]) + b_ds = b_e * (b_step - b_sign) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def parallel_softpick_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + if check_shared_mem('hopper', q.device.index): + BS = min(64, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(256, max(16, triton.next_power_of_2(V))) + elif check_shared_mem('ampere', q.device.index): + BS = min(32, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(128, max(16, triton.next_power_of_2(V))) + 
else: + BS = min(32, max(16, triton.next_power_of_2(T))) + BK = min(256, max(16, triton.next_power_of_2(K))) + BV = min(64, max(16, triton.next_power_of_2(V))) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + assert NK == 1, "The key dimension can not be larger than 256" + + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + grid = (NV, NT, B * HQ) + parallel_softpick_attn_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + offsets=offsets, + indices=indices, + B=B, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_softpick_attn_bwd_preprocess( + o: torch.Tensor, + do: torch.Tensor +): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float32) + parallel_softpick_attn_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_softpick_attn_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + scale: float = None, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + BS = max(16, triton.next_power_of_2(T)) + BS = min(32, BS) if check_shared_mem('ampere') else min(16, BS) + BK = max(16, triton.next_power_of_2(K)) + BV = max(16, triton.next_power_of_2(V)) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + delta = parallel_softpick_attn_bwd_preprocess(o, do) + + dq = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device) + dk = torch.empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float, device=q.device) + dv = torch.empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float, device=q.device) + grid = (NV, NT, B * HQ) + parallel_softpick_attn_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + parallel_softpick_attn_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum') + dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum') + return dq, dk, dv + + +@torch.compile +class ParallelSoftpickAttentionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, offsets): + ctx.dtype = q.dtype + + chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1]))) + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + + o, lse = parallel_softpick_attn_fwd( + q=q, + k=k, + v=v, + scale=scale, + chunk_size=chunk_size, + offsets=offsets, + 
indices=indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.chunk_size = chunk_size + ctx.offsets = offsets + ctx.indices = indices + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_softpick_attn_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + scale=ctx.scale, + chunk_size=ctx.chunk_size, + offsets=ctx.offsets, + indices=ctx.indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_softpick_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA will be applied if HQ is divisible by H. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + o = ParallelSoftpickAttentionFunction.apply(q, k, v, scale, cu_seqlens) + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o diff --git a/fla/ops/lightning_attn/__init__.py b/fla/ops/lightning_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c28c3af59f61d32cbb68a63926ac67fa2bb73447 --- /dev/null +++ b/fla/ops/lightning_attn/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_lightning_attn +from .fused_recurrent import fused_recurrent_lightning_attn + +__all__ = [ + 'chunk_lightning_attn', + 'fused_recurrent_lightning_attn' +] diff --git a/fla/ops/lightning_attn/__pycache__/__init__.cpython-311.pyc b/fla/ops/lightning_attn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38b52facedaa5d2f8196e3076e8a6e4e433c9ec7 Binary files /dev/null and b/fla/ops/lightning_attn/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/lightning_attn/__pycache__/chunk.cpython-311.pyc b/fla/ops/lightning_attn/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c2a6a78f06d12ad9affc7e01c1a1a8b57a0478f Binary files /dev/null and b/fla/ops/lightning_attn/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..367720e7463a22b3aa80f8ff62dc51a6b4612c82 Binary files /dev/null and b/fla/ops/lightning_attn/__pycache__/fused_recurrent.cpython-311.pyc differ diff 
--git a/fla/ops/lightning_attn/chunk.py b/fla/ops/lightning_attn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..d56d913acbdf72a827858bf65250fcd4f70e681d --- /dev/null +++ b/fla/ops/lightning_attn/chunk.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.simple_gla.chunk import chunk_simple_gla + + +@torch.compiler.disable +def chunk_lightning_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + num_layers: int, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + layer_idx (int): + The index of the current layer. + num_layers (int): + The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
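        A minimal call sketch (illustrative only; the shapes and the `layer_idx` / `num_layers`
        values below are arbitrary):

        Examples::

            >>> import torch
            >>> B, H, T, K, V = 2, 4, 1024, 64, 64
            >>> q = torch.randn(B, H, T, K, dtype=torch.bfloat16, device='cuda')
            >>> k = torch.randn(B, H, T, K, dtype=torch.bfloat16, device='cuda')
            >>> v = torch.randn(B, H, T, V, dtype=torch.bfloat16, device='cuda')
            >>> o, ht = chunk_lightning_attn(q, k, v, layer_idx=0, num_layers=12,
            ...                              output_final_state=True)
            >>> o.shape, ht.shape
            (torch.Size([2, 4, 1024, 64]), torch.Size([2, 4, 64, 64]))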
+ """ + H = q.shape[1] if head_first else q.shape[2] + s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float) + if head_first: + g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + else: + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + return chunk_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + head_first=head_first, + cu_seqlens=cu_seqlens + ) diff --git a/fla/ops/lightning_attn/fused_recurrent.py b/fla/ops/lightning_attn/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..6548188b7b617a994316696b8ee1237b064029c4 --- /dev/null +++ b/fla/ops/lightning_attn/fused_recurrent.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla + + +def fused_recurrent_lightning_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + num_layers: int, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + layer_idx (int): + The index of the current layer. + num_layers (int): + The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ """ + H = q.shape[1] if head_first else q.shape[2] + s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float) + if head_first: + g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + else: + g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous() + return fused_recurrent_simple_gla( + q=q, + k=k, + v=v, + g=g, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + head_first=head_first + ) diff --git a/fla/ops/rwkv6/__init__.py b/fla/ops/rwkv6/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3c7c218eb873a1a2115b5587530fe55f29a9d02 --- /dev/null +++ b/fla/ops/rwkv6/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_rwkv6 +from .fused_recurrent import fused_recurrent_rwkv6 + +__all__ = [ + 'chunk_rwkv6', + 'fused_recurrent_rwkv6' +] diff --git a/fla/ops/rwkv6/__pycache__/__init__.cpython-311.pyc b/fla/ops/rwkv6/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bb3632479e163c1550dec566ee6708d0cf13a5e Binary files /dev/null and b/fla/ops/rwkv6/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/rwkv6/__pycache__/chunk.cpython-311.pyc b/fla/ops/rwkv6/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58a109dc3112629a1e8168da0c6d50dff73ee535 Binary files /dev/null and b/fla/ops/rwkv6/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/rwkv6/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/rwkv6/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ef53f6487ed2e9131f53df1e860433e84e171f8 Binary files /dev/null and b/fla/ops/rwkv6/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/rwkv6/chunk.py b/fla/ops/rwkv6/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..b495dbb21f3023b5af2f744793b0fa58c90dc9e8 --- /dev/null +++ b/fla/ops/rwkv6/chunk.py @@ -0,0 +1,1465 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.chunk_h import chunk_fwd_h +from fla.ops.gla.chunk import chunk_gla_bwd_dA, chunk_gla_bwd_dv, chunk_gla_fwd_o_gk +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BS': BS}, num_warps=num_warps, num_stages=num_stages) + for BS in [16, 32, 64] + for num_warps in [4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['S', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_cumsum_kernel( + s, + oi, + oe, + offsets, + indices, + T, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = 
tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + m_i = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32) + m_e = tl.where(o_i[:, None] > o_i[None, :], 1., 0.).to(tl.float32) + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_oi = tl.make_block_ptr(oi + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_oe = tl.make_block_ptr(oe + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_oi = tl.make_block_ptr(oi + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_oe = tl.make_block_ptr(oe + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_oi = tl.dot(m_i, b_s) + b_oe = tl.dot(m_e, b_s) + tl.store(p_oi, b_oi.to(p_oi.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_oe, b_oe.to(p_oe.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def chunk_rwkv6_fwd_cumsum( + g: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + head_first: bool = True +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + gi, ge = torch.empty_like(g, dtype=torch.float), torch.empty_like(g, dtype=torch.float) + def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + # keep cummulative normalizer in fp32 + chunk_rwkv6_fwd_cumsum_kernel[grid]( + g, + gi, + ge, + offsets, + indices, + T=T, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first + ) + return gi, ge + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BC'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_inter( + q, + k, + gi, # cumulative decay inclusive + ge, # cumulative decay exclusive + A, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_i, i_j = i_c // NC, i_c % NC + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + if i_i <= i_j: + return + + m_i = i_t * BT + i_i * BC + tl.arange(0, BC) < T + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), 
(i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK) + else: + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gq = tl.where(m_i[:, None] & m_k, tl.load(p_gq, boundary_check=(0, 1)), float('-inf')) + b_qg = b_q * exp(b_gq - b_gn[None, :]) * scale + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp(b_gn[:, None] - b_gk) + # [BC, BC] using tf32 to improve precision here. + b_A += tl.dot(b_qg, b_kg) + + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BK', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_intra( + q, + k, + gi, + ge, + u, + A, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_j = i_i + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + if HEAD_FIRST: + o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + else: + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h 
* BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_qj = q + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_kj = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = gi + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (0,), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_q * b_kj[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i > j, b_A * scale, 0.) + b_A = tl.where(o_i != j, b_A, tl.sum(b_qj * b_kj * b_u * scale)) + tl.store(A + o_A + j, b_A, mask=m_A) + p_qj += K if HEAD_FIRST else H*K + p_kj += K if HEAD_FIRST else H*K + p_gk += K if HEAD_FIRST else H*K + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC', 'BK'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split( + q, + k, + gi, + ge, + u, + A, + offsets, + indices, + scale, + B: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_tc // NC, i_tc % NC + i_j = i_i + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + + if HEAD_FIRST: + o_A = (i_k * B*H + i_bh) * T * BC + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BC + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + else: + o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_qj = q + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_kj = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + 
p_gk = gi + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (i_k * BK), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_q * b_kj[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i > j, b_A * scale, 0.) + b_A = tl.where(o_i != j, b_A, tl.sum(b_qj * b_kj * b_u * scale)) + tl.store(A + o_A + j, b_A, mask=m_A) + p_qj += K if HEAD_FIRST else H*K + p_kj += K if HEAD_FIRST else H*K + p_gk += K if HEAD_FIRST else H*K + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge( + A, + A2, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + NK: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_c * BC >= T: + return + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(0, NK): + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh)*T*BC, (T, BC), (BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + b_A += tl.load(p_A, boundary_check=(0, 1)) + if HEAD_FIRST: + p_A2 = tl.make_block_ptr(A2 + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + else: + p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_bwd_kernel_dh( + q, + gi, + ge, + do, + dh, + dht, + dh0, + offsets, + chunk_offsets, + scale, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), 
tl.program_id(1), tl.program_id(2) + i_bg = i_nh // NG + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // NG + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT - 1, -1, -1): + if HEAD_FIRST: + p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + last_idx = min(i_t * BT + BT, T) - 1 + # [BK, BT] + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + if HEAD_FIRST: + p_gk = tl.make_block_ptr(ge + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gi + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk = tl.make_block_ptr(ge + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gi + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_q = (b_q * exp(b_gk) * scale).to(b_q.dtype) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) 
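        # A sketch of the backward recurrence for the chunkwise hidden-state gradient:
        # after dH for the current chunk has been stored above, it is carried to the
        # previous chunk as
        #     dH <- exp(g_last)[:, None] * dH + (q * exp(g_exclusive) * scale)^T @ do,
        # where g_exclusive (ge) is the within-chunk exclusive cumulative decay and
        # g_last (gi at the chunk's last position) is the total decay over the chunk.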
+ b_dh *= exp(b_gk_last)[:, None] + b_dh += tl.dot(b_q, b_do) + + if STORE_INITIAL_STATE_GRADIENT: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BK', 'NC', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_bwd_kernel_intra( + q, + k, + gi, + ge, + dA, + dq, + dk, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_c // NC, i_c % NC + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + if i_t * BT + i_i * BC >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + if HEAD_FIRST: + p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + else: + p_ge = tl.make_block_ptr(ge + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_ge = tl.load(p_ge, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + if i_i > 0: + if HEAD_FIRST: + p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK) + else: + p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h*K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(0, i_i): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + else: + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(gi+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp(b_gn[None, :] - b_gk) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg) + b_dq *= exp(b_ge - b_gn[None, :]) + + o_i = tl.arange(0, BC) + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + if HEAD_FIRST: + o_dA = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_gkj = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + else: + o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC + 
p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gkj = gi + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] > j + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_ge - b_gkj[None, :]), 0.) + p_kj += K if HEAD_FIRST else H*K + p_gkj += K if HEAD_FIRST else H*K + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(gi + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + if HEAD_FIRST: + p_gn = gi + i_bh * T*K + (min(i_t * BT + i_i * BC + BC, T) - 1)*K + o_k + p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) + else: + p_gn = gi + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h*K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + m_j = (i_t * BT + i_j * BC + tl.arange(0, BC)) < T + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (BT, T), (1, BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + else: + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gq = tl.where(m_j[:, None] & m_k, tl.load(p_gq, boundary_check=(0, 1)), float('-inf')) + b_qg = b_q * exp(b_gq - b_gn[None, :]) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. 
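+            # gradient flowing back into k of sub-chunk i from the later sub-chunk j:
+            #   dk_i += dA[j-block, i-block]^T @ (q_j * exp(g_q_j - g_n))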
+ b_dk += tl.dot(b_dA, b_qg) + b_dk *= exp(b_gn[None, :] - b_gk) + if HEAD_FIRST: + o_dA = i_bh * T*BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_gqj = tl.max_contiguous(tl.multiple_of(ge + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_dk = tl.make_block_ptr(dk + i_bh*T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + else: + o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) + p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gqj = ge + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j * (1 if HEAD_FIRST else H) * BT) + # [BK,] + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] < j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.) + p_qj += K if HEAD_FIRST else H*K + p_gqj += K if HEAD_FIRST else H*K + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_rwkv6_bwd_kernel_inter( + q, + k, + v, + h, + gi, + ge, + u, + do, + dh, + dA, + dq, + dk, + dq2, + dk2, + dg, + du, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + if HEAD_FIRST: + p_gk = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.max_contiguous(tl.multiple_of(gi + i_bh * T*K + (min(T, i_t * BT + BT)-1) * K + o_k, BK), BK) + else: + p_gk = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = gi + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK,], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, 
BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + b_dgk *= exp(b_gn) + b_dq *= scale + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)) + b_dq = b_dq * exp(b_gk) + b_dk = b_dk * exp(b_gn[None, :] - b_gi) + + o_i = tl.arange(0, BT) + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA_dig = dA + (i_bh * T + i_t * BT + o_i) * BT + o_i + else: + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA_dig = dA + ((bos + i_t * BT + o_i) * H + i_h) * BT + o_i + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] - b_q * b_dq + # [BT,] + b_dA_dig = tl.load(p_dA_dig, mask=(i_t * BT + o_i) < T, other=0) + + p_u = tl.make_block_ptr(u + i_h * K, (K,), (1,), (i_k * BK,), (BK,), (0,)) + b_u = tl.load(p_u, boundary_check=(0,)) + # scale is already applied to b_dA_diag + b_dq += (b_dA_dig[:, None] * b_u[None, :] * b_k) + b_dk += (b_dA_dig[:, None] * b_u[None, :] * b_q) + b_du = tl.sum(b_dA_dig[:, None] * b_q * b_k, axis=0) + p_du = tl.make_block_ptr(du + (i_tg * H + i_h) * K, (K,), (1,), (i_k * BK,), (BK,), (0,)) + tl.store(p_du, b_du, boundary_check=(0,)) + + if HEAD_FIRST: + p_dq = tl.make_block_ptr(dq2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + 
p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_rwkv6_fwd_intra( + q: torch.Tensor, + k: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + u: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K = k.shape + else: + B, T, H, K = k.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BC = min(16, BT) + NC = triton.cdiv(BT, BC) + + A = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + grid = (NT, NC * NC, B * H) + chunk_rwkv6_fwd_A_kernel_intra_sub_inter[grid]( + q, + k, + gi, + ge, + A, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + NC=NC, + HEAD_FIRST=head_first + ) + + grid = (NT, NC, B * H) + # load the entire [BC, K] blocks into SRAM at once + if K <= 256: + BK = triton.next_power_of_2(K) + chunk_rwkv6_fwd_A_kernel_intra_sub_intra[grid]( + q, + k, + gi, + ge, + u, + A, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + HEAD_FIRST=head_first + ) + # split then merge + else: + BK = min(128, triton.next_power_of_2(K)) + NK = triton.cdiv(K, BK) + A_intra = q.new_empty(NK, B, *((H, T) if head_first else (T, H)), BC, dtype=torch.float) + + grid = (NK, NT * NC, B * H) + chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split[grid]( + q, + k, + gi, + ge, + u, + A_intra, + offsets, + indices, + scale, + B=B, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + HEAD_FIRST=head_first + ) + + grid = (NT, NC, B * H) + chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge[grid]( + A_intra, + A, + offsets, + indices, + B=B, + T=T, + H=H, + BT=BT, + BC=BC, + NK=NK, + HEAD_FIRST=head_first + ) + return A + + +def chunk_rwkv6_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + do: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + scale: float, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + head_first: bool = True, + chunk_size: int = 64, + states_in_fp32: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + HQ = q.shape[1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + # N: the actual number of sequences in the batch with either equal or variable lengths + # NG: number of groups in GQA + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT = len(offsets) - 1, len(indices) + chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1) + NG = HQ // H + + if head_first: + dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + else: + dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None + + def grid(meta): return (triton.cdiv(K, meta['BK']), 
triton.cdiv(V, meta['BV']), N * H) + chunk_rwkv6_bwd_kernel_dh[grid]( + q=q, + gi=gi, + ge=ge, + do=do, + dh=dh, + dht=dht, + dh0=dh0, + offsets=offsets, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + HEAD_FIRST=head_first + ) + return dh, dh0 + + +def chunk_rwkv6_bwd_dqk_intra( + q: torch.Tensor, + k: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + dA: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K = q.shape + else: + B, T, H, K = q.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BK = min(64, triton.next_power_of_2(K)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + grid = (NK, NT * NC, B * H) + chunk_rwkv6_bwd_kernel_intra[grid]( + q, + k, + gi, + ge, + dA, + dq, + dk, + offsets, + indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + HEAD_FIRST=head_first + ) + return dq, dk + + +def chunk_rwkv6_bwd_dqkgu( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + u: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dA: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dq2 = torch.empty_like(dq) + dk2 = torch.empty_like(dk) + dg = torch.empty_like(g) + du = u.new_empty(B * NT, H, K, dtype=torch.float) + def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) + chunk_rwkv6_bwd_kernel_inter[grid]( + q, + k, + v, + h, + gi, + ge, + u, + do, + dh, + dA, + dq, + dk, + dq2, + dk2, + dg, + du, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + du = du.sum(0) + return dq2, dk2, dg, du + + +def chunk_rwkv6_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=chunk_size, offsets=offsets, indices=indices, head_first=head_first) + h, ht = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=gi, + gv=None, + h0=initial_state, + output_final_state=output_final_state, + offsets=offsets, + head_first=head_first, + chunk_size=chunk_size, + states_in_fp32=True + ) + # the intra A is kept in fp32 + # the computation has very marginal effect on the entire throughput + A = chunk_rwkv6_fwd_intra( + q=q, + k=k, + gi=gi, + ge=ge, + u=u, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + + o = chunk_gla_fwd_o_gk( + q=q, + v=v, + g=ge, + A=A, + h=h, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return A, h, 
ht, o + + +def chunk_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + gi, ge = chunk_rwkv6_fwd_cumsum(g, chunk_size=chunk_size, offsets=offsets, indices=indices, head_first=head_first) + h, _ = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=gi, + gv=None, + h0=initial_state, + output_final_state=False, + offsets=offsets, + head_first=head_first, + chunk_size=chunk_size, + states_in_fp32=True + ) + dh, dh0 = chunk_rwkv6_bwd_dh( + q=q, + k=k, + v=v, + gi=gi, + ge=ge, + do=do, + h0=initial_state, + dht=dht, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size, + states_in_fp32=True + ) + + # dq dk in fp32 + dA = chunk_gla_bwd_dA( + v=v, + do=do, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + dv = chunk_gla_bwd_dv( + k=k, + g=gi, + A=A, + do=do, + dh=dh, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + dq, dk = chunk_rwkv6_bwd_dqk_intra( + q=q, + k=k, + gi=gi, + ge=ge, + dA=dA, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + dq, dk, dg, du = chunk_rwkv6_bwd_dqkgu( + q=q, + k=k, + v=v, + h=h, + g=g, + gi=gi, + ge=ge, + u=u, + do=do, + dh=dh, + dA=dA, + dq=dq, + dk=dk, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return dq, dk, dv, dg, du, dh0 + + +class ChunkRWKV6Function(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + g, + u, + scale, + initial_state, + output_final_state, + offsets, + head_first + ): + T = q.shape[2] if head_first else q.shape[1] + chunk_size = min(32, max(32, triton.next_power_of_2(T))) if check_shared_mem() \ + else min(64, max(32, triton.next_power_of_2(T))) + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = None + if offsets is not None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + + A, h, ht, o = chunk_rwkv6_fwd( + q=q, + k=k, + v=v, + g=g, + u=u, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + + ctx.save_for_backward(q, k, v, g, initial_state, A, u) + + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.offsets = offsets + ctx.indices = indices + ctx.head_first = head_first + return o, ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + q, k, v, g, initial_state, A, u = ctx.saved_tensors + chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first + dq, dk, dv, dg, du, dh0 = chunk_rwkv6_bwd( + q=q, + k=k, + v=v, + g=g, + u=u, + scale=scale, + initial_state=initial_state, + A=A, + do=do, + dht=dht, + 
offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), du.to(u), None, dh0, None, None, None + + +@torch.compiler.disable +def chunk_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + u: torch.Tensor, + scale: Optional[int] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + g (torch.Tensor): + Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys. + u (torch.Tensor): + bonus representations of shape `[H]`. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (Optional[torch.Tensor]): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.rwkv6 import chunk_rwkv6 + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> u = torch.randn(H, K, device='cuda') + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = chunk_rwkv6(q, k, v, g, u, + initial_state=h0, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_rwkv6(q, k, v, g, u, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = q.shape[-1] ** -0.5 + o, final_state = ChunkRWKV6Function.apply( + q, + k, + v, + g, + u, + scale, + initial_state, + output_final_state, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/rwkv6/chunk_naive.py b/fla/ops/rwkv6/chunk_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..4a2ac664f5079a20eabe9b11c19c1cff6755c658 --- /dev/null +++ b/fla/ops/rwkv6/chunk_naive.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def naive_chunk_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + chunk_size: int = 32 +): + assert q.shape[-2] % chunk_size == 0 + orig_dtype = q.dtype + num_chunk = q.shape[-2] // chunk_size + u = u.unsqueeze(0) + + q, k, v, w = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size).float(), (q, k, v, w)) + + w_cumsum = w.cumsum(-2) + + kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp() + wkv = kw.transpose(-1, -2) @ v + + wkv_new = torch.zeros_like(wkv) + + for i in range(num_chunk - 1): + wkv_new[:, :, i+1] = (wkv_new[:, :, i] * w_cumsum[:, :, i, -1, :, None].exp()) + wkv[:, :, i] + + o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, (q * (w_cumsum - w).exp())) + + o_intra = torch.zeros_like(o_inter) + for i in range(chunk_size): + attn = (q[:, :, :, i, None] * k * (w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum).exp()).sum(-1) + mask = (torch.arange(0, chunk_size) < i).to(attn.device) + attn.masked_fill_(~mask, 0) + intra_inter_o = (attn.unsqueeze(-1) * v).sum(-2) + intra_intra_o = (q[:, :, :, i] * u.unsqueeze(2) * k[:, :, :, i]).sum(-1).unsqueeze(-1) * v[:, :, :, i] + o_intra[:, :, :, i] = intra_inter_o + intra_intra_o + o = o_inter + o_intra + return rearrange(o, 'b h n c d -> b h (n c) d').to(orig_dtype) diff --git a/fla/ops/rwkv6/fused_recurrent.py b/fla/ops/rwkv6/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..2ff9d2c3cd5492236c1120bd20d4b513584f1da4 --- /dev/null +++ b/fla/ops/rwkv6/fused_recurrent.py @@ -0,0 +1,709 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_rwkv6_fwd_kernel( + q, # query [B, H, T, K]/[B, T, H, K] + k, # key [B, H, T, K]/[B, T, H, K] + v, # value [B, H, T, V]/[B, T, H, V] + w, # log gate [B, H, T]/[B, T, H] or None + u, # bonus [B, H, K] + o, # output [NK, B, H, T, V]/[NK, B, T, H, V] + h0, # initial hidden state [B, 
H, K, V] + ht, # final hidden state [B, H, K, V] + offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, # whether to reverse the recurrence + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + if HEAD_FIRST: + p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v + p_w = w + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + o_v + else: + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_u = u + i_h * K + o_k + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32) + b_kv = b_k[:, None] * b_v[None, :] + b_o = tl.sum((b_h + b_kv * b_u[:, None]) * b_q[:, None], 0) + b_h = b_h * exp(b_w)[:, None] + b_kv + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_w += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + ], + key=['BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_rwkv6_bwd_kernel_dq( + k, # key [B, H, T, V]/[B, T, H, V] + v, # value [B, H, T, V]/[B, T, H, V] + w, # log gate [B, H, T]/[B, T, H] + u, # bonus [B, H, K] + do, # gradient of output [B, H, T, V]/[B, T, H, V] + dq, # gradient of query [NV, B, H, T, K]/[NV, B, T, H, K] + dq1, # gradient of query_aux [NV, B, H, T, K]/[NV, B, T, H, K] + h0, + offsets, + 
scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + if HEAD_FIRST: + p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v + p_w = w + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v + p_dq = dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_dq1 = dq1 + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + o_k + else: + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_w = w + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_dq1 = dq1 + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_u = u + i_h * K + o_k + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_kv = b_k[:, None] * b_v[None, :] + + b_hq = b_h * b_do[None, :] + b_dq = tl.sum(b_hq + b_kv * b_u[:, None] * b_do[None, :], 1) * scale + b_dq1 = tl.sum(b_hq, 1) + b_h = b_h * exp(b_w)[:, None] + b_h += b_kv + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k) + tl.store(p_dq1, b_dq1.to(p_dq1.dtype.element_ty), mask=mask_k) + + p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_w += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_dq1 += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + ], + key=['BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_rwkv6_bwd_kernel_dkv( + q, # query [B, H, T, K]/[B, T, H, K] + k, # key [B, H, T, V]/[B, T, H, V] + v, # value [B, H, T, V]/[B, T, H, V] + w, # log gate [B, H, T]/[B, T, H] + u, # bonus [B, H, K] + do, # gradient of output [B, H, T, V]/[B, T, H, V] + dk, # gradient of key [NV, B, H, T, K]/[NK, B, T, H, K] + dk1, # 
gradient of key_aux [NV, B, H, T, K]/[NK, B, T, H, K] + dv, # gradient of value [NK, B, H, T, V]/[NV, B, T, H, V] + dh0, # gradient of initial hidden state [N, H, K, V] + offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + if HEAD_FIRST: + p_q = q + i_nh * T*K + ((T-1) * K if not REVERSE else 0) + o_k + p_k = k + i_nh * T*K + ((T-1) * K if not REVERSE else 0) + o_k + p_v = v + i_nh * T*V + ((T-1) * V if not REVERSE else 0) + o_v + p_w = w + i_nh * T*K + ((T-1) * K if not REVERSE else 0) + o_k + p_do = do + i_nh * T*V + ((T-1) * V if not REVERSE else 0) + o_v + p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T-1) * K if not REVERSE else 0) + o_k + p_dk1 = dk1 + (i_v * B*H + i_nh) * T*K + ((T-1) * K if not REVERSE else 0) + o_k + p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T-1) * V if not REVERSE else 0) + o_v + else: + p_q = q + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if not REVERSE else 0)) * H*V + i_h * V + o_v + p_w = w + (bos + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_do = do + (bos + ((T-1) if not REVERSE else 0)) * H*V + i_h * V + o_v + p_dk = dk + ((i_v * all + bos) + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_dk1 = dk1 + ((i_v * all + bos) + ((T-1) if not REVERSE else 0)) * H*K + i_h * K + o_k + p_dv = dv + ((i_k * all + bos) + ((T-1) if not REVERSE else 0)) * H*V + i_h * V + o_v + p_u = u + i_h * K + o_k + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_u = tl.load(p_u, mask=mask_k, other=0).to(tl.float32) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for _ in range(T - 1, -1, -1): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_w = tl.load(p_w, mask=mask_k, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_dkv = b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], 1) + tl.store(p_dk1, b_dk.to(p_dk1.dtype.element_ty), mask=mask_k) + b_dk += tl.sum(b_dkv * b_u[:, None] * b_v[None, :], 1) + b_dv = tl.sum((b_dh + (b_dkv * b_u[:, None])) * b_k[:, None], 0) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v) + b_dh *= exp(b_w)[:, None] + b_dh += b_dkv + + p_q += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_k += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_v += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_w += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_do += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_dk += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_dk1 += (-1 if not REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_dv += (-1 if not 
REVERSE else 1) * (1 if HEAD_FIRST else H) * V + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT, 'BK': BK}, num_warps=num_warps) + for BT in [16, 32, 64] + for BK in [32, 64] + for num_warps in [1, 2, 4, 8] + ], + key=['K'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_rwkv6_bwd_kernel_dw( + q, + k, + dq, + dk, + dw, + offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + REVERSE: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_k, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + NT = tl.cdiv(T, BT) + + o_i = tl.arange(0, BT) + m_i = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) if not REVERSE else tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) + + b_z = tl.zeros([BK], dtype=tl.float32) + + i_t = 0 if not REVERSE else NT - 1 + for _ in range(NT): + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_nh * T*K, (T, K), (K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + i_nh * T*K, (T, K), (K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_nh * T*K, (T-1, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_nh * T*K, (T-1, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_q = tl.make_block_ptr(q + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + 1, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T-1, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T-1, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_dq = tl.load(p_dq, boundary_check=(0, 1)).to(tl.float32) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_dk = tl.load(p_dk, boundary_check=(0, 1)).to(tl.float32) + b_dw = (b_q * b_dq * scale) - b_k * b_dk + b_c = b_z[None, :] + tl.dot(m_i, b_dw, allow_tf32=False) + tl.store(p_dw, b_c.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + if i_t >= 0: + b_z += tl.sum(b_dw, 0) + + i_t += (1 if not REVERSE else -1) + + +def fused_recurrent_rwkv6_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + h0 = initial_state + ht = 
q.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None + o = q.new_empty(NK, *v.shape, dtype=torch.float) + + grid = (NV, NK, N * H) + fused_recurrent_rwkv6_fwd_kernel[grid]( + q, + k, + v, + w, + u, + o, + h0, + ht, + offsets, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + REVERSE=reverse, + HEAD_FIRST=head_first + ) + o = o.sum(0) + return o, ht + + +def fused_recurrent_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + do: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + + BK, BV = min(triton.next_power_of_2(K), 16), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dq = q.new_empty(NV, *q.shape, dtype=torch.float) + dq1 = torch.empty_like(dq) + + grid = (NV, NK, N * H) + fused_recurrent_rwkv6_bwd_kernel_dq[grid]( + k, + v, + w, + u, + do, + dq, + dq1, + initial_state, + offsets, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + REVERSE=reverse, + HEAD_FIRST=head_first + ) + dq = dq.sum(0) + dq1 = dq1.sum(0) + + BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dk = q.new_empty(NV, *k.shape, dtype=torch.float) + dk1 = q.new_empty(NV, *k.shape, dtype=torch.float) + dv = q.new_empty(NK, *v.shape, dtype=torch.float) + + dh0 = torch.empty_like(initial_state) if initial_state is not None else None + grid = (NV, NK, N * H) + fused_recurrent_rwkv6_bwd_kernel_dkv[grid]( + q, + k, + v, + w, + u, + do, + dk, + dk1, + dv, + dh0, + offsets, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + REVERSE=reverse, + HEAD_FIRST=head_first + ) + dk = dk.sum(0) + dk1 = dk1.sum(0) + dv = dv.sum(0) + + dw = torch.empty_like(w) + def grid(meta): return (triton.cdiv(meta['K'], meta['BK']), N * H) + fused_recurrent_rwkv6_bwd_kernel_dw[grid]( + q, + k, + dq1, + dk1, + dw, + offsets, + scale, + T=T, + H=H, + K=K, + REVERSE=not reverse, + HEAD_FIRST=head_first + ) + du = (do.float() * v).sum(-1, True, dtype=torch.float) * q * k * scale + du = du.sum((0, 2)) if head_first else du.sum((0, 1)) + return dq, dk, dv, dw, du, dh0 + + +class FusedRecurrentRWKV6Function(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True + ): + o, ht = fused_recurrent_rwkv6_fwd( + q=q, + k=k, + v=v, + w=w, + u=u, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + offsets=offsets, + head_first=head_first + ) + ctx.save_for_backward(q, k, v, w, u, initial_state) + ctx.scale = scale + ctx.reverse = reverse + ctx.offsets = offsets + ctx.head_first = head_first + return o.to(v), ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + q, k, v, w, u, initial_state = ctx.saved_tensors + + dq, dk, dv, dw, du, dh0 = fused_recurrent_rwkv6_bwd( + q=q, + k=k, + v=v, + w=w, 
+ u=u, + do=do, + scale=ctx.scale, + initial_state=initial_state, + reverse=ctx.reverse, + offsets=ctx.offsets, + head_first=ctx.head_first + ) + dh0 = dh0.to(initial_state) if dh0 is not None else dh0 + return dq.to(q), dk.to(k), dv.to(v), dw.to(w), du.to(u), None, dh0, None, None, None, None + + +def fused_recurrent_rwkv6( + r: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[int] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + r (torch.Tensor): + reception of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + Alias: q, query in linear attention. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + w (torch.Tensor): + data-dependent decays of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` in log space! Alias: g. + u (torch.Tensor): + bonus of shape `[H, K]` + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the state passing in reverse order. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (Optional[torch.Tensor]): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
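+
+        As a sketch, each recurrence step computes
+        `o_t = (h_{t-1} + diag(u) k_t v_t^T)^T (q_t * scale)`
+        and then updates the state as `h_t = diag(exp(w_t)) h_{t-1} + k_t v_t^T`,
+        where the decays `w` are given in log space.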
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.rwkv6 import fused_recurrent_rwkv6 + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> u = torch.randn(H, K, device='cuda') + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = fused_recurrent_rwkv6(q, k, v, g, u, + initial_state=h0, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_rwkv6(q, k, v, g, u, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if r.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {r.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = FusedRecurrentRWKV6Function.apply( + r, + k, + v, + w, + u, + scale, + initial_state, + output_final_state, + reverse, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/rwkv6/recurrent_naive.py b/fla/ops/rwkv6/recurrent_naive.py new file mode 100644 index 0000000000000000000000000000000000000000..ba2268759b5d4ce7f9be1be1f9c2e1a2f2a8e6c3 --- /dev/null +++ b/fla/ops/rwkv6/recurrent_naive.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_rwkv6( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +): + orig_dtype = q.dtype + B, H, T, K, V = *q.shape, v.shape[-1] + q, k, v, w, u = map(lambda x: x.float(), (q, k, v, w, u)) + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + o = torch.zeros_like(v) + + if scale is None: + scale = K ** -0.5 + + if initial_state is not None: + h += initial_state + + for i in range(T): + q_i = q[:, :, i, :] * scale + k_i = k[:, :, i] + v_i = v[:, :, i, :] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + o_i = (h + u[None, ..., None] * kv_i) * q_i[..., None] + o[:, :, i] = o_i.sum(-2) + h = h * w_i[..., None] + kv_i + ht = h if output_final_state else None + return o.to(orig_dtype), ht + + +@torch.no_grad +@torch.jit.script +def naive_recurrent_rwkv6_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + 
initial_state: Optional[torch.Tensor] = None +): + q, k, v, w, u, o, do = (x.to(dtype=torch.float32) for x in (q, k, v, w, u, o, do)) + B, H, T, K, V = q.shape[0], q.shape[1], q.shape[2], q.shape[3], v.shape[-1] + h = torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device) + dq = torch.zeros_like(q) + dq_aux = torch.zeros_like(q) + + if initial_state is not None: + h += initial_state + + for i in range(T): + k_i = k[:, :, i] + v_i = v[:, :, i] + w_i = w[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h_i = (h + u[None, ..., None] * kv_i) + dq_i = (do[:, :, i, None, :] * h_i).sum(-1) + dq_aux_i = (do[:, :, i, None, :] * h).sum(-1) + dq[:, :, i] = dq_i + dq_aux[:, :, i] = dq_aux_i + h = h * w_i[..., None] + kv_i + + du = torch.zeros_like(u) + dh = torch.zeros_like(h) + dk = torch.zeros_like(k) + dk_aux = torch.zeros_like(k) + dv = torch.zeros_like(v) + + for i in range(T - 1, -1, -1): + d_kv_i = do[:, :, i, None, :] * q[:, :, i, :, None] + k_i = k[:, :, i] + v_i = v[:, :, i] + du_i = (d_kv_i * k_i[..., None] * v_i[..., None, :]).sum(-1) + du += du_i.sum(0) + dk_i = (dh * v_i[..., None, :]).sum(-1) + dk_aux[:, :, i] = dk_i + dk_i += (d_kv_i * u[None, ..., None] * v_i[..., None, :]).sum(-1) + dv_i = (d_kv_i * u[None, ..., None] * k_i[..., None]).sum(-2) + dv_i += (dh * k_i[..., None]).sum(-2) + + dk[:, :, i] = dk_i + dv[:, :, i] = dv_i + dh = dh * w[:, :, i, :, None].exp() + d_kv_i + + # dw = q * dq_aux - k * dk_aux + dw = torch.zeros_like(w) + for i in range(T - 2, -1, -1): + dw[:, :, i] = dw[:, :, i+1] + dq_aux[:, :, i+1] * q[:, :, i+1] - dk_aux[:, :, i] * k[:, :, i] + + return dq, dk, dv, dw, du, dh diff --git a/fla/ops/rwkv7/__init__.py b/fla/ops/rwkv7/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5efa2f31d0fa68361a4ff19e65d797157e03c83d --- /dev/null +++ b/fla/ops/rwkv7/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_rwkv7 +from .fused_recurrent import fused_recurrent_rwkv7 + +__all__ = [ + 'chunk_rwkv7', + 'fused_recurrent_rwkv7' +] diff --git a/fla/ops/rwkv7/__pycache__/__init__.cpython-311.pyc b/fla/ops/rwkv7/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..414cf807d207fece082ddd4e86d4a11c077dcf7c Binary files /dev/null and b/fla/ops/rwkv7/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/rwkv7/__pycache__/chunk.cpython-311.pyc b/fla/ops/rwkv7/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a7ef3f13415fcd5ae57dd53761a481db03b16aa Binary files /dev/null and b/fla/ops/rwkv7/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/rwkv7/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/rwkv7/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7df0e051436611219f22f41bd370982e5ba7395 Binary files /dev/null and b/fla/ops/rwkv7/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/rwkv7/channel_mixing.py b/fla/ops/rwkv7/channel_mixing.py new file mode 100644 index 0000000000000000000000000000000000000000..991ea426f086859b8eaf0623f5acf059f9bddc5c --- /dev/null +++ b/fla/ops/rwkv7/channel_mixing.py @@ -0,0 +1,323 @@ +import logging + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_pytorch_version, input_guard, use_cuda_graph + +logger = logging.getLogger(__name__) + +if not check_pytorch_version('2.4'): + 
logger.warning('PyTorch < 2.4 detected - computations may be slower due to lack of optimizations') + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': block_size}) + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192] + ], + key=['hidden_dim'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def rwkv_seq_mix_kernel( + x_ptr, + x_prev_ptr, + mix_k_ptr, + output_ptr, + batch_size: tl.constexpr, + token_length, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr +): + block_start = tl.program_id(0) * BLOCK_SIZE + block_idx = block_start + tl.arange(0, BLOCK_SIZE)[:] + + total_seq_dim = token_length * hidden_dim + batch_idx = block_idx // total_seq_dim + seq_and_feat = block_idx % total_seq_dim + seq_idx = seq_and_feat // hidden_dim + feat_idx = seq_and_feat % hidden_dim + + is_valid = (batch_idx < batch_size) & (seq_idx < token_length) + + x_idx = batch_idx * total_seq_dim + seq_idx * hidden_dim + feat_idx + + curr_x = tl.load(x_ptr + x_idx, mask=is_valid, other=0.0).to(tl.float32) + k_value = tl.load(mix_k_ptr + feat_idx).to(tl.float32) + + is_first = seq_idx < 1 + prev_state_idx = batch_idx * hidden_dim + feat_idx + prev_state = tl.load(x_prev_ptr + prev_state_idx, + mask=(is_first & is_valid), + other=0.0).to(tl.float32) + + prev_x_idx = x_idx - hidden_dim + prev_x = tl.load(x_ptr + prev_x_idx, + mask=(~is_first & is_valid), + other=0.0).to(tl.float32) + + prev_value = tl.where(is_first, prev_state, prev_x) + state_diff = prev_value - curr_x + mixed = state_diff * k_value + result = tl.cast(curr_x + mixed, dtype=output_ptr.dtype.element_ty, fp_downcast_rounding='rtne') + tl.store(output_ptr + x_idx, result, mask=is_valid) + + +@triton.jit +def rwkv_channel_mixing_pow_and_relu( + in_ptr, + out_ptr, + BLOCK_SIZE: tl.constexpr +): + """Fused ReLU and Power operation: x = ReLU(x)^2""" + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE) + x0 = xindex + x = tl.load(in_ptr + (x0), None) + x = tl.maximum(x, 0.0).to(tl.float32) + x = tl.cast(x * x, dtype=out_ptr.dtype.element_ty, fp_downcast_rounding='rtne') + tl.store(out_ptr + (x0), x, None) + + +def rwkv_mix_torch(x: torch.Tensor, x_prev: torch.Tensor, x_k: torch.Tensor): + if x_prev.dim() == 2: + x_prev = x_prev.unsqueeze(1) # (batch_size, 1, hidden_dim) + xx = torch.cat((x_prev, x[:, :-1, :]), dim=1) - x + k = x.addcmul(xx, x_k) + return k + + +def rwkv_relu_and_square_torch(x: torch.Tensor): + return torch.relu(x) ** 2 + + +def rwkv_mix_fwd(x, x_prev, x_k): + has_batch = x.dim() == 3 + + if has_batch: + batch_size, token_length, hidden_dim = x.shape + else: + token_length, hidden_dim = x.shape + batch_size = 1 + x = x.unsqueeze(0) + x_prev = x_prev.unsqueeze(0) + + token_length = x.shape[1] + hidden_dim = x.shape[2] + total_elements = batch_size * token_length * hidden_dim + + output = torch.empty_like(x) + + def grid(meta): return ( + (total_elements + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], # grid_0 + 1, # grid_1 + 1 # grid_2 + ) + + rwkv_seq_mix_kernel[grid]( + x.contiguous(), + x_prev.contiguous(), + x_k.squeeze(), + output, + batch_size=batch_size, + token_length=token_length, + hidden_dim=hidden_dim, + ) + if not has_batch: + output = output.squeeze(0) + return output + + +def rwkv_relu_and_square_fwd(x: torch.Tensor, inplace: bool = True): + """ + Triton implementation of RWKV's ReLU and square operation + Args: + x: Input tensor + Returns: + Tensor after ReLU and square operations + """ + x = x.contiguous() + output = x if inplace else torch.empty_like(x) + + def 
grid(meta): return ( + (output.numel() + meta['BLOCK_SIZE'] - 1) // meta['BLOCK_SIZE'], # grid_0 + 1, # grid_1 + 1 # grid_2 + ) + rwkv_channel_mixing_pow_and_relu[grid]( + x, + output, + BLOCK_SIZE=4096, + ) + + return output + + +@triton.jit +def relu_square_bwd_kernel( + out_ptr, + forward_input_ptr, + BLOCK_SIZE: tl.constexpr +): + """ReLU(x)^2 backward kernel + grad_input = grad_output * 2 * x if x > 0 else 0 + """ + pid = tl.program_id(0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + + x = tl.load(forward_input_ptr + offsets).to(tl.float32) + grad = tl.load(out_ptr + offsets).to(tl.float32) + + x = tl.maximum(x, 0.0) + + grad_input = grad * 2 * x + + tl.store(out_ptr + offsets, grad_input.to(out_ptr.dtype.element_ty)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': block_size}) + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192] + ], + key=['hidden_dim'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def rwkv_mix_bwd_kenel( + dk1_ptr0, + xk_ptr, + dx_ptr, + dx_prev_ptr, + batch_size, + token_length, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + + batch_idx = offsets // (token_length * hidden_dim) + seq_feat = offsets % (token_length * hidden_dim) + seq_idx = seq_feat // hidden_dim + feat_idx = seq_feat % hidden_dim + + is_valid = offsets < (batch_size * token_length * hidden_dim) + + dk1 = tl.load(dk1_ptr0 + offsets, mask=is_valid) + xk = tl.load(xk_ptr + feat_idx, mask=is_valid) + prod = dk1 * xk + + mask_next = seq_idx < (token_length - 1) + next_offset = offsets + hidden_dim + dk1_next = tl.load(dk1_ptr0 + next_offset, mask=mask_next & is_valid, other=0.0) + prod_next = dk1_next * xk + dx_val = dk1 - prod + tl.where(mask_next, prod_next, 0.0) + dx_val = tl.cast(dx_val, dtype=dx_ptr.dtype.element_ty, fp_downcast_rounding='rtne') + tl.store(dx_ptr + offsets, dx_val, mask=is_valid) + + dx_prev_offset = batch_idx * hidden_dim + feat_idx + is_first_step = seq_idx == 0 + + tl.store( + dx_prev_ptr + dx_prev_offset, + tl.cast(prod, dtype=dx_prev_ptr.dtype.element_ty), + mask=is_first_step + ) + + +@torch.compile(fullgraph=True) +def compute_x_k_grad(dk1, x, x_prev): + """ + Args: + dk1: (batch*seq_len, hidden_dim) + x: (batch, seq_len, hidden_dim) + x_prev: (batch, hidden_dim) or (batch, 1, hidden_dim) + """ + + if x_prev.dim() == 2: + x_prev = x_prev.unsqueeze(1) # (batch, 1, hidden_dim) + xx = torch.cat((x_prev, x[:, :-1, :]), dim=1) - x # (batch, seq_len, hidden_dim) + + # (hidden_dim,) --> (1, 1, hidden_dim) + grad_x_k = (dk1 * xx.reshape(-1, x.shape[2])).sum(dim=0).view(1, 1, -1) + return grad_x_k + + +def rwkv_channel_mixing_bwd(grad_output, x, x_prev, x_k, key_weight, value_weight, k1, k1_K, k, inplace=True): + batch_size = x.shape[0] if x.dim() == 3 else 1 + seq_len, n_embd = x.shape[-2], x.shape[-1] + + dV = k.transpose(-2, -1) @ grad_output + dk = grad_output @ value_weight.transpose(-2, -1) + + BLOCK_SIZE = 4096 + grid = ((dk.numel() + BLOCK_SIZE - 1) // BLOCK_SIZE,) + relu_square_bwd_kernel[grid]( + dk, + k1_K, + BLOCK_SIZE=BLOCK_SIZE + ) + + dK = k1.transpose(-2, -1) @ dk + dk1 = dk @ key_weight.transpose(-2, -1) + dk1 = dk1.view(-1, n_embd).contiguous() + + dk_reduced = compute_x_k_grad(dk1, x, x_prev) + dx_prev = torch.empty_like(x_prev) if not inplace else x_prev + dx = torch.empty_like(x) if not inplace else x + + def grid(meta): return ((batch_size * seq_len * n_embd + meta['BLOCK_SIZE'] - 1) // 
meta['BLOCK_SIZE'], 1, 1) + rwkv_mix_bwd_kenel[grid]( + dk1, + x_k.squeeze(), + dx, + dx_prev, + batch_size, + seq_len, + n_embd, + ) + # dx_prev.shape batch_size, seq_len, n_embd + return dx, dx_prev, dk_reduced, dK, dV + + +class Rwkv7ChannelMixing(torch.autograd.Function): + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, x, x_prev, x_k, key_weight, value_weight, inplace: bool = True): + k1 = rwkv_mix_fwd(x, x_prev, x_k) + k1_K = k1 @ key_weight + k = rwkv_relu_and_square_fwd(k1_K, inplace=True) + ctx.save_for_backward(x, x_prev, x_k, key_weight, value_weight) + ctx.inplace = inplace + return k @ value_weight + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, dkv): + x, x_prev, x_k, key_weight, value_weight = ctx.saved_tensors + k1 = rwkv_mix_fwd(x, x_prev, x_k) + k1_K = k1 @ key_weight + k = rwkv_relu_and_square_fwd(k1_K, inplace=False) + dx, dx_prev, dk_reduced, dK, dV = rwkv_channel_mixing_bwd( + dkv, x, x_prev, x_k, key_weight, value_weight, k1, k1_K, k, ctx.inplace) + return dx, dx_prev, dk_reduced, dK, dV, None + + +def channel_mixing_rwkv7(x: torch.Tensor, x_prev: torch.Tensor, x_k: torch.Tensor, + key_weight: torch.Tensor, value_weight: torch.Tensor, inplace: bool = True): + assert x.dim() == 3 + + return Rwkv7ChannelMixing.apply(x, x_prev, x_k, key_weight, value_weight, inplace), x[-1, :] + + +def channel_mixing_rwkv7_torch(x, x_prev, x_k, key_weight, value_weight): + k1 = rwkv_mix_torch(x, x_prev, x_k) + k1_K = k1 @ key_weight + k = rwkv_relu_and_square_torch(k1_K) + return k @ value_weight, x[-1, :] diff --git a/fla/ops/rwkv7/chunk.py b/fla/ops/rwkv7/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..956c458974316d1f83121bd71ef1ec433cb6cdde --- /dev/null +++ b/fla/ops/rwkv7/chunk.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch + +from fla.ops.generalized_delta_rule import chunk_dplr_delta_rule + + +def chunk_rwkv7( + r: torch.Tensor, + w: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float = 1.0, + initial_state: torch.Tensor = None, + output_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +): + """ + Args: + r (torch.Tensor): + r of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + w (torch.Tensor): + log decay of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + k of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + a (torch.Tensor): + a of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + b (torch.Tensor): + b of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + scale (float): + scale of the attention. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (bool): + whether to use head first. Recommended to be False to avoid extra transposes. 
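+    Returns:
+        o (torch.Tensor):
+            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
+        final_state (torch.Tensor):
+            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+    Examples::
+        >>> import torch
+        >>> import torch.nn.functional as F
+        >>> from fla.ops.rwkv7 import chunk_rwkv7
+        # a minimal shape-level sketch: random tensors with the documented layouts
+        # (`head_first=False`); in the RWKV-7 layer, `a`/`b` are typically built from
+        # an L2-normalized key as `a = -kk` and `b = kk * a_lr` with `a_lr` in [0, 1]
+        >>> B, T, H, K, V = 4, 2048, 4, 64, 64
+        >>> r = torch.randn(B, T, H, K, device='cuda')
+        >>> w = F.logsigmoid(torch.randn(B, T, H, K, device='cuda'))
+        >>> k = torch.randn(B, T, H, K, device='cuda')
+        >>> v = torch.randn(B, T, H, V, device='cuda')
+        >>> kk = F.normalize(torch.randn(B, T, H, K, device='cuda'), dim=-1)
+        >>> a, b = -kk, kk * torch.rand(B, T, H, K, device='cuda')
+        >>> h0 = torch.randn(B, H, K, V, device='cuda')
+        >>> o, ht = chunk_rwkv7(r, w, k, v, a, b,
+                                initial_state=h0,
+                                output_final_state=True,
+                                head_first=False)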
+ """ + return chunk_dplr_delta_rule( + q=r, + k=k, + v=v, + a=a, + b=b, + gk=w, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + head_first=head_first + ) diff --git a/fla/ops/rwkv7/fused_addcmul.py b/fla/ops/rwkv7/fused_addcmul.py new file mode 100644 index 0000000000000000000000000000000000000000..43fe0711d3a555b57bdeb1365be7ae6f2bab8477 --- /dev/null +++ b/fla/ops/rwkv7/fused_addcmul.py @@ -0,0 +1,262 @@ +# -*- coding: utf-8 -*- + +import logging +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import check_pytorch_version, device, input_guard, use_cuda_graph + +logger = logging.getLogger(__name__) + +if not check_pytorch_version('2.4'): + logger.warning('PyTorch < 2.4 detected - computations may be slower due to lack of optimizations') + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps) + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192] + for num_warps in [1, 2, 4, 8, 16, 32] + ], + key=['hidden_dim'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def fused_addcmul_fwd_kernel( + hidden_ptr, + x_ptr, + ixr_ptr, + ixw_ptr, + ixk_ptr, + ixv_ptr, + ixa_ptr, + ixg_ptr, + oxr_ptr, + oxw_ptr, + oxk_ptr, + oxv_ptr, + oxa_ptr, + oxg_ptr, + use_xg: tl.constexpr, + xnumel, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + + valid_indices = xnumel - xoffset + xmask = xindex < (xoffset + valid_indices) + x0 = xindex % hidden_dim + b_hiddn = tl.load(hidden_ptr + (xindex), xmask, other=0.).to(tl.float32) + b_x = tl.load(x_ptr + (xindex), xmask, other=0.).to(tl.float32) + b_ixr = tl.load(ixr_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_ixw = tl.load(ixw_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_ixk = tl.load(ixk_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_ixv = tl.load(ixv_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_ixa = tl.load(ixa_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_oxr = b_hiddn + b_x * b_ixr + b_oxw = b_hiddn + b_x * b_ixw + b_oxk = b_hiddn + b_x * b_ixk + b_oxv = b_hiddn + b_x * b_ixv + b_oxa = b_hiddn + b_x * b_ixa + + tl.store(oxr_ptr + (xindex), b_oxr.to(oxr_ptr.dtype.element_ty), xmask) + tl.store(oxw_ptr + (xindex), b_oxw.to(oxw_ptr.dtype.element_ty), xmask) + tl.store(oxk_ptr + (xindex), b_oxk.to(oxk_ptr.dtype.element_ty), xmask) + tl.store(oxv_ptr + (xindex), b_oxv.to(oxv_ptr.dtype.element_ty), xmask) + tl.store(oxa_ptr + (xindex), b_oxa.to(oxa_ptr.dtype.element_ty), xmask) + + if use_xg: + b_ixg = tl.load(ixg_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_oxg = b_hiddn + b_x * b_ixg + tl.store(oxg_ptr + (xindex), b_oxg.to(oxg_ptr.dtype.element_ty), xmask) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE': block_size}, num_warps=num_warps) + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192] + for num_warps in [1, 2, 4, 8, 16, 32] + ], + key=['hidden_dim'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def addcmul_bwd_kernel1( + ixr_ptr, + ixw_ptr, + ixk_ptr, + ixv_ptr, + ixa_ptr, + ixg_ptr, + dxr_ptr, + dxw_ptr, + dxk_ptr, + dxv_ptr, + dxa_ptr, + dxg_ptr, + ghidden_ptr, + gx_ptr, + use_xg: tl.constexpr, + xnumel, + hidden_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + xoffset = tl.program_id(0) * BLOCK_SIZE + xindex = xoffset + tl.arange(0, BLOCK_SIZE)[:] + + 
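# xmask guards the tail of the last block (positions with xindex >= xnumel) for the stores below, + # while x0 recovers the feature index used to address the per-channel mixing vectors x_r..x_g +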
valid_indices = xnumel - xoffset + xmask = xindex < (xoffset + valid_indices) + x0 = xindex % hidden_dim + + b_dxr = tl.load(dxr_ptr + (xindex), None).to(tl.float32) + b_dxw = tl.load(dxw_ptr + (xindex), None).to(tl.float32) + b_dxk = tl.load(dxk_ptr + (xindex), None).to(tl.float32) + b_dxv = tl.load(dxv_ptr + (xindex), None).to(tl.float32) + b_dxa = tl.load(dxa_ptr + (xindex), None).to(tl.float32) + b_ixr = tl.load(ixr_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_ixw = tl.load(ixw_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_iwk = tl.load(ixk_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_ixv = tl.load(ixv_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + b_ixa = tl.load(ixa_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + + if use_xg: + b_dxg = tl.load(dxg_ptr + (xindex), None).to(tl.float32) + b_ixg = tl.load(ixg_ptr + (x0), eviction_policy='evict_last').to(tl.float32) + g_hidden = b_dxr + b_dxw + b_dxk + b_dxv + b_dxa + b_dxg + g_x = b_dxr * b_ixr + b_dxw * b_ixw + b_dxk * b_iwk + b_dxv * b_ixv + b_dxa * b_ixa + b_dxg * b_ixg + else: + g_hidden = b_dxr + b_dxw + b_dxk + b_dxv + b_dxa + g_x = b_dxr * b_ixr + b_dxw * b_ixw + b_dxk * b_iwk + b_dxv * b_ixv + b_dxa * b_ixa + + tl.store(ghidden_ptr + (xindex), g_hidden.to(ghidden_ptr.dtype.element_ty), xmask) + tl.store(gx_ptr + (xindex), g_x.to(gx_ptr.dtype.element_ty), xmask) + + +def addcmul_bwd1(d_oxr, d_oxw, d_oxk, d_oxv, d_oxa, d_oxg, x_r, x_w, x_k, x_v, x_a, x_g, hidden_states, xx, use_xg): + d_hiddn = torch.empty_like(hidden_states) + d_xx = torch.empty_like(xx) + numel = hidden_states.numel() + def grid(meta): return (triton.cdiv(meta['xnumel'], meta['BLOCK_SIZE']),) + addcmul_bwd_kernel1[grid]( + ixr_ptr=x_r, + ixw_ptr=x_w, + ixk_ptr=x_k, + ixv_ptr=x_v, + ixa_ptr=x_a, + ixg_ptr=x_g, + dxr_ptr=d_oxr, + dxw_ptr=d_oxw, + dxk_ptr=d_oxk, + dxv_ptr=d_oxv, + dxa_ptr=d_oxa, + dxg_ptr=d_oxg, + ghidden_ptr=d_hiddn, + gx_ptr=d_xx, + use_xg=use_xg, + xnumel=numel, + hidden_dim=hidden_states.size(-1), + ) + return d_hiddn, d_xx + + +@torch.compile(fullgraph=True) +def addcmul_bwd2(d_oxr, d_oxw, d_oxk, d_oxv, d_oxa, d_oxg, xx, use_xg: bool): + g_xr = (d_oxr * xx).sum(dim=(0, 1), keepdim=True) + g_xw = (d_oxw * xx).sum(dim=(0, 1), keepdim=True) + g_xk = (d_oxk * xx).sum(dim=(0, 1), keepdim=True) + g_xv = (d_oxv * xx).sum(dim=(0, 1), keepdim=True) + g_xa = (d_oxa * xx).sum(dim=(0, 1), keepdim=True) + g_xg = (d_oxg * xx).sum(dim=(0, 1), keepdim=True) if use_xg else None + return g_xr, g_xw, g_xk, g_xv, g_xa, g_xg + + +class Rwkv7FusedAddcmul(torch.autograd.Function): + @staticmethod + @input_guard + def forward(ctx, hidden_states, xx, + x_r, x_w, x_k, x_v, x_a, x_g, + num_elements + ): + oxr = torch.empty_like(hidden_states) + oxw = torch.empty_like(hidden_states) + oxk = torch.empty_like(hidden_states) + oxv = torch.empty_like(hidden_states) + oxa = torch.empty_like(hidden_states) + if x_g is not None: + use_xg = True + oxg = torch.empty_like(hidden_states) + else: + use_xg = False + oxg = None + ctx.save_for_backward(hidden_states, xx, + x_r, x_w, x_k, x_v, x_a, x_g) + ctx.use_xg = use_xg + + def grid(meta): return (triton.cdiv(meta['xnumel'], meta['BLOCK_SIZE']),) + fused_addcmul_fwd_kernel[grid]( + hidden_states, + xx, + x_r, + x_w, + x_k, + x_v, + x_a, + x_g, + oxr, + oxw, + oxk, + oxv, + oxa, + oxg, + use_xg, + num_elements, + hidden_states.size(-1), + ) + return oxr, oxw, oxk, oxv, oxa, oxg + + @staticmethod + @input_guard + def backward(ctx, dxr, + dxw, dxk, dxv, dxa, dxg): + 
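# dxr..dxg are the upstream gradients of the six addcmul outputs; + # addcmul_bwd1 folds them into elementwise gradients for hidden_states and xx, + # and addcmul_bwd2 reduces over (batch, seq) to obtain gradients for the x_r..x_g mixing vectors +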
hidden_states, xx, x_r, x_w, x_k, x_v, x_a, x_g = ctx.saved_tensors + + d_hiddn, d_xx = addcmul_bwd1(dxr, dxw, dxk, dxv, dxa, dxg, x_r, x_w, x_k, x_v, x_a, x_g, hidden_states, xx, ctx.use_xg) + + d_ixr, d_ixw, d_ixk, d_ixv, d_ixa, d_ixg = addcmul_bwd2(dxr, dxw, dxk, dxv, dxa, dxg, xx, ctx.use_xg) + + return d_hiddn, d_xx, d_ixr, d_ixw, d_ixk, d_ixv, d_ixa, d_ixg, None + + +def fused_addcmul_rwkv7( + hidden_states: torch.Tensor, + xx: torch.Tensor, + xr: torch.Tensor, + xw: torch.Tensor, + xk: torch.Tensor, + xv: torch.Tensor, + xa: torch.Tensor, + xg: Optional[torch.Tensor] = None +): + num_elements = hidden_states.numel() + if num_elements < 16777216 and device == "cuda": + return torch_addcmul_rwkv7(hidden_states, xx, xr, xw, xk, xv, xa, xg) + else: + return Rwkv7FusedAddcmul.apply(hidden_states, xx, xr, xw, xk, xv, xa, xg, num_elements) + + +def torch_addcmul_rwkv7(hidden_states, xx, xr, xw, xk, xv, xa, xg=None): + oxr = torch.addcmul(hidden_states, xx, xr) + oxw = torch.addcmul(hidden_states, xx, xw) + oxk = torch.addcmul(hidden_states, xx, xk) + oxv = torch.addcmul(hidden_states, xx, xv) + oxa = torch.addcmul(hidden_states, xx, xa) + if xg is not None: + oxg = torch.addcmul(hidden_states, xx, xg) + return oxr, oxw, oxk, oxv, oxa, oxg + else: + return oxr, oxw, oxk, oxv, oxa, None diff --git a/fla/ops/rwkv7/fused_recurrent.py b/fla/ops/rwkv7/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce2d15aec9c6995f2df26992c89a29182f0169d --- /dev/null +++ b/fla/ops/rwkv7/fused_recurrent.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch + +from fla.ops.generalized_delta_rule import fused_recurrent_dplr_delta_rule + + +def fused_recurrent_rwkv7( + r: torch.Tensor, + w: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float = 1.0, + initial_state: torch.Tensor = None, + output_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +): + """ + Args: + r (torch.Tensor): + r of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + w (torch.Tensor): + log decay of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + k of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + v of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + a (torch.Tensor): + a of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + b (torch.Tensor): + b of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + scale (float): + scale of the attention. + initial_state (torch.Tensor): + initial state of shape `[B, H, K, V]` if cu_seqlens is None else `[N, H, K, V]` where N = len(cu_seqlens) - 1. + output_final_state (bool): + whether to output the final state. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (bool): + whether to use head first. Recommended to be False to avoid extra transposes. + """ + return fused_recurrent_dplr_delta_rule( + q=r, + k=k, + v=v, + a=a, + b=b, + gk=w, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + head_first=head_first + )