diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..4087bffe9a04984860364e29e5f819a7c8c0ebbf 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +checkpoint/step-60000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-70000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-90000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-100000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-20000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-80000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-10000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-50000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-1/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-30000/.metadata filter=lfs diff=lfs merge=lfs -text +checkpoint/step-40000/.metadata filter=lfs diff=lfs merge=lfs -text diff --git a/checkpoint/step-1/.metadata b/checkpoint/step-1/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..80c8ed4344a1145ac5e2df41d58fdba42bcdfdcc --- /dev/null +++ b/checkpoint/step-1/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a7a1dc8097b7a6f7f04d8d3b1ac59dbca7331cd4f7554d465556a775cc8fb2a3 +size 1966399 diff --git a/checkpoint/step-10000/.metadata b/checkpoint/step-10000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..13ed7355e29678b7f0be98e758d6e252dd3ff343 --- /dev/null +++ b/checkpoint/step-10000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8230628b2b9318c77cc629a747246173e1f08994304d6eb6c277b937fe4a122 +size 1966605 diff --git a/checkpoint/step-100000/.metadata b/checkpoint/step-100000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..80be37df6ba42ae2f2ac645af753f129b81d52bb --- /dev/null +++ b/checkpoint/step-100000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20172fb8269576c2752f0c9c5571db66977472808e11f8a08a62ecebc8c1ca3f +size 1966953 diff --git a/checkpoint/step-20000/.metadata b/checkpoint/step-20000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..5e2f4d3d34fb6fa3a6855f708fd8e8f64e3c260a --- /dev/null +++ b/checkpoint/step-20000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6129edb991ed42f312a5437d565bb7921e331f0c8ee6e8d2b9271fd28d05350 +size 1966726 diff --git a/checkpoint/step-30000/.metadata b/checkpoint/step-30000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..c8f5840f6fef6b7959e75aa89854ea173fe62d2a --- /dev/null +++ b/checkpoint/step-30000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cdd2f42f953103801f648d026a8c5473671a3970a42b90dbec3d61eb6887bba6 +size 1966842 diff --git a/checkpoint/step-40000/.metadata b/checkpoint/step-40000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..8ede82a7ecc3c4a82119301ac78e9f71896cc8d8 --- /dev/null +++ b/checkpoint/step-40000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1be5eabe3423f35d02413a59af0177d6cfc9e7a30fad1bf5d7c98ae6740fe503 +size 1966870 diff --git a/checkpoint/step-50000/.metadata 
b/checkpoint/step-50000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..699c11a7f34d83a347db34dc658414a125dc7e9e --- /dev/null +++ b/checkpoint/step-50000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b8de54a3bd58bec547d87c515963acf7b678e1976263b0c7d6106842d07b8ce +size 1966890 diff --git a/checkpoint/step-60000/.metadata b/checkpoint/step-60000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..689cebf8646ce7e81ebf557266024391fa99288e --- /dev/null +++ b/checkpoint/step-60000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cdf7894da9066d4b1b627693dec1f3df1211e7a894995740daaa7d77b7dd1985 +size 1966912 diff --git a/checkpoint/step-70000/.metadata b/checkpoint/step-70000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..dedd23bd634ed1699c42ca551d0c43aaf0db2a79 --- /dev/null +++ b/checkpoint/step-70000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:843a213824fe630f60d4c48810db333f8ffec4138a436c837403aba0704abad8 +size 1966934 diff --git a/checkpoint/step-80000/.metadata b/checkpoint/step-80000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..bf3408c6dc91b5ba368ee9b6d11f87809358b121 --- /dev/null +++ b/checkpoint/step-80000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65cc9b063d0a6a992190ccd8efccbf5142716eb5c2082a9ae581bbede3d15fd7 +size 1966952 diff --git a/checkpoint/step-90000/.metadata b/checkpoint/step-90000/.metadata new file mode 100644 index 0000000000000000000000000000000000000000..917cc2e618bea7bf1351fbc741dfd03fa3865829 --- /dev/null +++ b/checkpoint/step-90000/.metadata @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a25b7f564d152ef6984a3e23b67a896d57c0aae5a269437b077ef5bee760039 +size 1966952 diff --git a/fla/ops/__pycache__/__init__.cpython-311.pyc b/fla/ops/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c23873702ea4a7fcdaeb20eda55f54559b2d2552 Binary files /dev/null and b/fla/ops/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/attn/__pycache__/naive_softpick.cpython-311.pyc b/fla/ops/attn/__pycache__/naive_softpick.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52fe0c1059748ab58f9cbf487b7b9c8d56c2e57b Binary files /dev/null and b/fla/ops/attn/__pycache__/naive_softpick.cpython-311.pyc differ diff --git a/fla/ops/attn/__pycache__/parallel_rectified.cpython-311.pyc b/fla/ops/attn/__pycache__/parallel_rectified.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3049f0e5e9061611f8fe3c02d29024316889147b Binary files /dev/null and b/fla/ops/attn/__pycache__/parallel_rectified.cpython-311.pyc differ diff --git a/fla/ops/attn/__pycache__/parallel_softpick.cpython-311.pyc b/fla/ops/attn/__pycache__/parallel_softpick.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..908fa9f853bc86dc10678ecfb02557b6fbe9cdd4 Binary files /dev/null and b/fla/ops/attn/__pycache__/parallel_softpick.cpython-311.pyc differ diff --git a/fla/ops/gated_delta_rule/__init__.py b/fla/ops/gated_delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad7f86639b3482c78768cf0511d2eb2650305e7f --- /dev/null +++ b/fla/ops/gated_delta_rule/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_gated_delta_rule +from 
.fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule" +] diff --git a/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-311.pyc b/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce358ca6a0da2cacbc8bd839a32936e854412a9 Binary files /dev/null and b/fla/ops/gated_delta_rule/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/gated_delta_rule/__pycache__/chunk.cpython-311.pyc b/fla/ops/gated_delta_rule/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fcaf7d1d25993bd9c94860fd7253b529037e590 Binary files /dev/null and b/fla/ops/gated_delta_rule/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78001ab9d8011c1a70a5ac627284173cb997a425 Binary files /dev/null and b/fla/ops/gated_delta_rule/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-311.pyc b/fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6bb0e5c26fc3db4be3e27031b315a1f89ecd05b Binary files /dev/null and b/fla/ops/gated_delta_rule/__pycache__/wy_fast.cpython-311.pyc differ diff --git a/fla/ops/gated_delta_rule/chunk.py b/fla/ops/gated_delta_rule/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..abbb52a56fbaf62a4c818c32217dc8c95a0e2292 --- /dev/null +++ b/fla/ops/gated_delta_rule/chunk.py @@ -0,0 +1,392 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +from einops import rearrange + +from fla.modules.l2norm import l2norm_bwd, l2norm_fwd +from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h +from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv_local, chunk_fwd_o +from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr, fwd_prepare_wy_repr, fwd_recompute_w_u +from fla.ops.utils import chunk_local_cumsum +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) + # obtain WY representation. u is actually the new v. 
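+    # Roughly speaking: Aw and Au are the per-chunk inverses of the lower-triangular
+    # matrices built from the beta-scaled keys (Au additionally folds in the decay g),
+    # and w = Aw @ (beta * k), u = Au @ (beta * v), so the per-step rank-1 updates can
+    # be applied with chunk-level matmuls inside chunk_gated_delta_rule_fwd_h below.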
+ w, u, Aw, Au = fwd_prepare_wy_repr( + k=k, + v=v, + beta=beta, + g=g, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + + # obtain output + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return g, o, Aw, Au, final_state + + +def chunk_gated_delta_rule_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + Aw: torch.Tensor, + Au: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + w, u = fwd_recompute_w_u( + k=k, + v=v, + beta=beta, + Aw=Aw, + Au=Au, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + h, v_new, _ = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=False, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dv = chunk_bwd_dv_local( + q=q, + k=k, + g=g, + do=do, + dh=None, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dh, dh0, dv = chunk_gated_delta_rule_bwd_dhu( + q=q, + k=k, + w=w, + g=g, + h0=initial_state, + dht=dht, + do=do, + dv=dv, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dq, dk, dw, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dk2, dv, db, dg2 = bwd_prepare_wy_repr( + k=k, + v=v, + beta=beta, + g=g, + Aw=Aw, + Au=Au, + dw=dw, + du=dv, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dk.add_(dk2) + dg.add_(dg2) + assert dg.dtype == torch.float32, "dg should be fp32" + dg = chunk_local_cumsum(dg, chunk_size, reverse=True, offsets=offsets, indices=indices, head_first=head_first) + return dq, dk, dv, db, dg, dh0 + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True, + use_qk_l2norm_in_kernel: bool = False + ): + chunk_size = 64 + q_orig = q + k_orig = k + + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = None + if offsets is not None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) + indices 
= torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + + g, o, Aw, Au, final_state = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size, + ) + ctx.save_for_backward(q_orig, k_orig, v, g, beta, Aw, Au, initial_state, offsets, indices) + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.head_first = head_first + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + q, k, v, g, beta, Aw, Au, initial_state, offsets, indices = ctx.saved_tensors + if ctx.use_qk_l2norm_in_kernel: + q, q_orig = l2norm_fwd(q), q + k, k_orig = l2norm_fwd(k), k + dq, dk, dv, db, dg, dh0 = chunk_gated_delta_rule_bwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + Aw=Aw, + Au=Au, + scale=ctx.scale, + initial_state=initial_state, + do=do, + dht=dht, + offsets=offsets, + indices=indices, + head_first=ctx.head_first, + chunk_size=ctx.chunk_size + ) + if ctx.use_qk_l2norm_in_kernel: + dq = l2norm_bwd(q_orig, dq) + dk = l2norm_bwd(k_orig, dk) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), db.to(beta), None, dh0, None, None, None, None + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + head_first=False + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, H, T] if head_first=True, or [B, T, H] if head_first=False." + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if head_first: + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + beta, g = map(lambda x: rearrange(x, 'b h t -> b t h'), (beta, g)) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "Scale must be positive." 
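+    # head_first inputs were already rearranged to the time-first layout above,
+    # so the autograd function is always invoked with head_first=False; the output
+    # is rearranged back to [B, H, T, V] after the call if needed.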
+ o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + False, + use_qk_l2norm_in_kernel + ) + if head_first: + o = rearrange(o, 'b t h v -> b h t v') + return o, final_state diff --git a/fla/ops/gated_delta_rule/fused_recurrent.py b/fla/ops/gated_delta_rule/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..4c73b8a40be4044982d714f5922a7fc324a4fbb8 --- /dev/null +++ b/fla/ops/gated_delta_rule/fused_recurrent.py @@ -0,0 +1,321 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * H + i_h) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * H + i_h) * V + o_v + else: + p_beta = beta + bos * H + i_h + p_g = g + bos * H + i_h + p_o = o + ((i_k * all + bos) * H + i_h) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6) + b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_q += H*K + p_k += H*K + p_o += H*V + p_v += H*V + p_g += H + p_beta += H * (V if IS_BETA_HEADWISE else 1) + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, 
b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + use_qk_l2norm_in_kernel: bool = False, + offsets: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if output_final_state: + final_state = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + final_state = None + + grid = (NK, NV, N * H) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + offsets=offsets, + scale=scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + offsets=offsets + ) + + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + raise NotImplementedError( + "Backward pass is not implemented yet and we do not have plans to implement it " + "because we haven't figured out how to compute dg without materializing the full " + "hidden states for all time steps." + ) + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + g (decays) of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `(B, H, T)`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. 
+ output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, H, device='cuda')) + >>> beta = torch.rand(B, T, H, device='cuda').sigmoid() + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = fused_recurrent_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True, ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens ) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if head_first: + raise RuntimeError( + "Sequences with variable lengths are not supported for head-first mode" + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + if head_first: + q, k, v, g, beta = map(lambda x: rearrange(x, 'b h t ...
-> b t h ...'), (q, k, v, g, beta)) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel + ) + if head_first: + o = rearrange(o, 'b t h v -> b h t v') + return o, final_state diff --git a/fla/ops/gated_delta_rule/wy_fast.py b/fla/ops/gated_delta_rule/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..f80b2251f32e60dda83735f74183546b15ef45a0 --- /dev/null +++ b/fla/ops/gated_delta_rule/wy_fast.py @@ -0,0 +1,620 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import safe_exp +from fla.utils import check_shared_mem + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'BK', 'BC', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk32( + k, + g, + beta, + Aw, + Au, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_Aw = tl.zeros([BC, BC], dtype=tl.float32) + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_Aw += tl.dot(b_kb, tl.trans(b_k)) + + b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0) + + if HEAD_FIRST: + p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + + b_g = tl.load(p_g, boundary_check=(0,)) + b_Au = b_Aw * safe_exp(b_g[:, None] - b_g[None, :]) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0) + b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0) + b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i) + b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i) + b_Aw = tl.where(mask[:, None], b_aw, b_Aw) + b_Au = tl.where(mask[:, None], b_au, b_Au) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + if HEAD_FIRST: + p_Aw = tl.make_block_ptr(Aw + i_bh * T * 
BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + else: + p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + tl.store(p_Aw, b_Aw.to(p_Aw.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au, b_Au.to(p_Au.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'BK', 'BC', 'USE_OFFSETS', 'HEAD_FIRST'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk64( + k, + g, + beta, + Aw, + Au, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_Aw = tl.zeros([BC, BC], dtype=tl.float32) + b_Aw2 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aw3 = tl.zeros([BC, BC], dtype=tl.float32) + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,)) + p_beta2 = tl.make_block_ptr(beta + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,)) + p_beta2 = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,)) + + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_beta2 = tl.load(p_beta2, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None]).to(b_k.dtype) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)) + b_kb2 = (b_k2 * b_beta2[:, None]).to(b_k2.dtype) + b_Aw += tl.dot(b_kb, tl.trans(b_k)) + b_Aw2 += tl.dot(b_kb2, tl.trans(b_k2)) + b_Aw3 += tl.dot(b_kb2, tl.trans(b_k)) + + b_Aw = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw, 0) + b_Aw2 = -tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_Aw2, 0) + + if HEAD_FIRST: + p_g = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT,), (BC,), (0,)) + p_g2 = tl.make_block_ptr(g + i_bh*T, (T,), (1,), (i_t * BT + BC,), (BC,), (0,)) + else: + p_g = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT,), (BC,), (0,)) + p_g2 = tl.make_block_ptr(g + bos*H + i_h, (T,), (H,), (i_t * BT + BC,), (BC,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g2 = tl.load(p_g2, boundary_check=(0,)) + + mask_c = tl.arange(0, BC)[:, None] >= tl.arange(0, BC)[None, 
:] + mask_g = i_t * BT + tl.arange(0, BC) < T + mask_g2 = i_t * BT + BC + tl.arange(0, BC) < T + + b_Au = tl.where(mask_g[None, :] & mask_c, b_Aw * safe_exp(b_g[:, None] - b_g[None, :]), 0) + b_Au2 = tl.where(mask_g2[None, :] & mask_c, b_Aw2 * safe_exp(b_g2[:, None] - b_g2[None, :]), 0) + b_Au3 = tl.where(mask_g[None, :], b_Aw3 * safe_exp(b_g2[:, None] - b_g[None, :]), 0) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_aw = tl.sum(tl.where(mask[:, None], b_Aw, 0), 0) + b_aw2 = tl.sum(tl.where(mask[:, None], b_Aw2, 0), 0) + b_au = tl.sum(tl.where(mask[:, None], b_Au, 0), 0) + b_au2 = tl.sum(tl.where(mask[:, None], b_Au2, 0), 0) + b_aw = b_aw + tl.sum(b_aw[:, None] * b_Aw, 0) * (tl.arange(0, BC) < i) + b_aw2 = b_aw2 + tl.sum(b_aw2[:, None] * b_Aw2, 0) * (tl.arange(0, BC) < i) + b_au = b_au + tl.sum(b_au[:, None] * b_Au, 0) * (tl.arange(0, BC) < i) + b_au2 = b_au2 + tl.sum(b_au2[:, None] * b_Au2, 0) * (tl.arange(0, BC) < i) + b_Aw = tl.where(mask[:, None], b_aw, b_Aw) + b_Aw2 = tl.where(mask[:, None], b_aw2, b_Aw2) + b_Au = tl.where(mask[:, None], b_au, b_Au) + b_Au2 = tl.where(mask[:, None], b_au2, b_Au2) + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_Aw += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Aw2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + # improve precision by disallowing tf32. + b_Aw3 = -tl.dot(tl.dot(b_Aw2, b_Aw3, allow_tf32=False), b_Aw, allow_tf32=False) + b_Au += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Au2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_Au3 = -tl.dot(tl.dot(b_Au2, b_Au3, allow_tf32=False), b_Au, allow_tf32=False) + + if HEAD_FIRST: + p_Aw1 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Aw2 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_Aw3 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_Aw4 = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + p_Au1 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au2 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_Au3 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_Au4 = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + else: + p_Aw1 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Aw2 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_Aw3 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_Aw4 = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + p_Au1 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_Au2 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_Au3 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_Au4 = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + + tl.store(p_Aw1, b_Aw.to(p_Aw1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aw2, 
b_Aw2.to(p_Aw2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aw3, b_Aw3.to(p_Aw3.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aw4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Aw4.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au1, b_Au.to(p_Au1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au2, b_Au2.to(p_Au2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au3, b_Au3.to(p_Au3.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Au4, tl.zeros([BC, BC], dtype=tl.float32).to(p_Au4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def fwd_recompute_w_u_kernel( + k, + v, + beta, + w, + u, + Aw, + Au, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_Au = tl.make_block_ptr(Au + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_Au = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_Au = tl.load(p_Au, boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_Au, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + b_Au = None + if HEAD_FIRST: + p_Aw = tl.make_block_ptr(Aw + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_Aw = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_Aw = tl.load(p_Aw, boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = 
(b_k * b_beta[:, None]).to(b_k.dtype) + b_w = tl.dot(b_Aw, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def fwd_prepare_wy_repr( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K = k.shape + else: + B, T, H, K = k.shape + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BC = min(BT, 32) + BK = min(triton.next_power_of_2(K), 64) + # bf16 should be good enough. + Aw = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype) + Au = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=k.dtype) + + fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32 + fwd_fn[(NT, B*H)]( + k=k, + g=g, + beta=beta, + Aw=Aw, + Au=Au, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + BT=BT, + BK=BK, + BC=BC, + HEAD_FIRST=head_first + ) + w, u = fwd_recompute_w_u( + k=k, + v=v, + beta=beta, + Aw=Aw, + Au=Au, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return w, u, Aw, Au + + +def fwd_recompute_w_u( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + Aw: torch.Tensor, + Au: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + + u = torch.empty_like(v) + w = torch.empty_like(k) + fwd_recompute_w_u_kernel[(NT, B*H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + Aw=Aw, + Au=Au, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return w, u + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'] +) +@triton.jit(do_not_specialize=['T']) +def bwd_prepare_wy_repr_kernel( + k, + v, + beta, + g, + Aw, + Au, + dw, + du, + dk, + dv, + dbeta, + dg, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_dbeta = tl.zeros([BT], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(Aw + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + 
else: + p_beta = tl.make_block_ptr(beta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(Aw + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False) + b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False) + b_dk = b_dk_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty) + + if HEAD_FIRST: + p_A = tl.make_block_ptr(Au + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + else: + p_A = tl.make_block_ptr(Au + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA2 = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA2 += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False) + b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False) + b_dv = b_dv_beta * b_beta[:, None] + b_dbeta += tl.sum(b_dv_beta * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA2, 0) + b_dA2 = tl.dot(b_dA2.to(b_A.dtype), b_A) + b_dA2 = tl.dot(b_A, b_dA2.to(b_A.dtype)) + b_dA2 = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA2, 0).to(k.dtype.element_ty) + if HEAD_FIRST: + p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_g = tl.make_block_ptr(g + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, 
boundary_check=(0,)) + b_dA2 *= safe_exp(b_g[:, None] - b_g[None, :]) + b_dA += b_dA2 + b_dA = b_dA.to(k.dtype.element_ty) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.load(p_dk, boundary_check=(0, 1)) + b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype) + b_A += tl.dot(b_k_beta, tl.trans(b_k)) + b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False) + b_dbeta += tl.sum(b_dk_beta * b_k, 1) + b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False) + b_dk += b_dk_beta * b_beta[:, None] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + b_dA2 *= b_A + b_dg = tl.sum(b_dA2, axis=1) - tl.sum(b_dA2, axis=0) + if HEAD_FIRST: + p_dg = tl.make_block_ptr(dg + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_dg = tl.make_block_ptr(dg + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_dbeta = tl.make_block_ptr(dbeta + (bos*H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,)) + + +def bwd_prepare_wy_repr( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + Aw: torch.Tensor, + Au: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + + dk = torch.empty_like(k) + dv = torch.empty_like(v) + dbeta = torch.empty_like(beta) + dg = torch.empty_like(g) + bwd_prepare_wy_repr_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + g=g, + Aw=Aw, + Au=Au, + dw=dw, + du=du, + dk=dk, + dv=dv, + dbeta=dbeta, + dg=dg, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dk, dv, dbeta, dg diff --git a/fla/ops/generalized_delta_rule/README.md b/fla/ops/generalized_delta_rule/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f96c22f44a51ad3e6fdeb824eb2aded660223600 --- /dev/null +++ b/fla/ops/generalized_delta_rule/README.md @@ -0,0 +1,37 @@ +# Generalized Delta Rule + +In delta rule we have the recurrence: + +```math +\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}-\beta_t \mathbf{k}_t\mathbf{k}_t^T) + \beta_t \mathbf{v}_t\mathbf{k}_t^T +``` + +This repository implements a delta rule variant where $\mathbf{I}$ is not necessarily an identity matrix; $\mathbf{k}_t$ in $\mathbf{I} - \beta_t \mathbf{k}_t\mathbf{k}_t^T$ might be different from input $\mathbf{k}_t$ 
in $\mathbf{v}_t\mathbf{k}_t^T$. + +## IPLR (Identity Plus Low Rank) + +The first variant is IPLR, where we have: + +```math +\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{I}+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T +``` + +When $\mathbf{a}_t = -\beta_t \mathbf{k}_t$, $\mathbf{b}_t = \mathbf{k}_t$, $\mathbf{v}_t= \beta_t \mathbf{v}_t$, we recover the original delta rule. Since here the transition matrix is identity-plus-low-rank, we refer to this variant as IPLR. + +### Numerical Stability + +$\mathbf{a}_t$ and $\mathbf{b}_t$ must be in opposite directions, that is, $\mathbf{b}_t = \lambda_t \mathbf{a}_t$ where $\lambda_t < 0$. For an understanding of why this is necessary, you can derive the eigenvalues of the transition matrix. + +## DPLR (Diagonal Plus Low Rank) + +The second variant is DPLR, where we have: + +```math +\mathbf{S}_t = \mathbf{S}_{t-1}(\mathbf{D}_t+\mathbf{a}_t\mathbf{b}_t^T) + \mathbf{v}_t\mathbf{k}_t^T +``` + +Here, $\mathbf{I}$ is replaced by a diagonal matrix $\mathbf{D}_t$. This transition matrix structure has been utilized in RWKV7. + +## Efficient Chunkwise Implementation + +For detailed information about efficient chunkwise implementation, please refer to our [technical note](https://drive.google.com/file/d/1rJbO3dU4fe7OKG3w7Yg058z_BNIuavNF/view?usp=sharing). diff --git a/fla/ops/generalized_delta_rule/__init__.py b/fla/ops/generalized_delta_rule/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b4155a215ca8c44ea45d6b151b1e584872ed6c --- /dev/null +++ b/fla/ops/generalized_delta_rule/__init__.py @@ -0,0 +1,9 @@ +from .dplr import chunk_dplr_delta_rule, fused_recurrent_dplr_delta_rule +from .iplr import chunk_iplr_delta_rule, fused_recurrent_iplr_delta_rule + +__all__ = [ + 'chunk_dplr_delta_rule', + 'fused_recurrent_dplr_delta_rule', + 'chunk_iplr_delta_rule', + 'fused_recurrent_iplr_delta_rule' +] diff --git a/fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-311.pyc b/fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0900d500d82f37c0ceb7e2ff668d02f86cf1e6ad Binary files /dev/null and b/fla/ops/generalized_delta_rule/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/chunk.py b/fla/ops/generalized_delta_rule/dplr/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..eac6af87a2bbc7c3dea56b1aea874c9347ccb5dd --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/chunk.py @@ -0,0 +1,388 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton + +from fla.ops.common.utils import prepare_chunk_indices +from fla.ops.generalized_delta_rule.dplr.chunk_A_bwd import chunk_dplr_bwd_dqk_intra +from fla.ops.generalized_delta_rule.dplr.chunk_A_fwd import chunk_fwd_intra_dplr_fn +from fla.ops.generalized_delta_rule.dplr.chunk_h_bwd import chunk_dplr_bwd_dhu +from fla.ops.generalized_delta_rule.dplr.chunk_h_fwd import chunk_dplr_fwd_h +from fla.ops.generalized_delta_rule.dplr.chunk_o_bwd import chunk_dplr_bwd_dAu, chunk_dplr_bwd_dv, chunk_dplr_bwd_o +from fla.ops.generalized_delta_rule.dplr.chunk_o_fwd import chunk_dplr_fwd_o +from fla.ops.generalized_delta_rule.dplr.wy_fast_bwd import chunk_dplr_bwd_wy +from fla.ops.generalized_delta_rule.dplr.wy_fast_fwd import fwd_prepare_wy_repr +from fla.ops.rwkv6.chunk import chunk_rwkv6_fwd_cumsum +from fla.utils import autocast_custom_bwd, 
autocast_custom_fwd, input_guard + + +def chunk_dplr_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first) + + A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + scale=scale, + offsets=offsets, + indices=indices, + chunk_size=BT, + head_first=head_first + ) + del ge + + # A_ab, A_ak, gi, ge torch.float32 + # A_qk, A_qb, qg, kg, ag, bg, dtype=q.dtype, eg: bf16 + w, u, _ = fwd_prepare_wy_repr( + ag=ag, + A_ab=A_ab, + A_ak=A_ak, + v=v, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + del A_ab, A_ak + h, v_new, final_state = chunk_dplr_fwd_h( + kg=kg, + bg=bg, + v=v, + w=w, + u=u, + gk=gi, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + del u, kg, bg, gi + + o = chunk_dplr_fwd_o( + qg=qg, + v=v, + v_new=v_new, + A_qk=A_qk, + A_qb=A_qb, + h=h, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + del v_new, h, A_qk, A_qb + + return o, final_state + + +class ChunkDPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True + ): + chunk_size = 16 + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + + o, final_state = chunk_dplr_fwd( + q=q, + k=k, + v=v, + a=a, + b=b, + gk=gk, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + ctx.save_for_backward(q, k, v, a, b, gk, initial_state) + ctx.head_first = head_first + ctx.offsets = offsets + ctx.indices = indices + ctx.scale = scale + ctx.chunk_size = chunk_size + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + q, k, v, a, b, gk, initial_state = ctx.saved_tensors + BT = ctx.chunk_size + head_first = ctx.head_first + offsets = ctx.offsets + indices = ctx.indices + scale = ctx.scale + + # ******* start recomputing everything, otherwise i believe the gpu memory will be exhausted ******* + gi, ge = chunk_rwkv6_fwd_cumsum(gk, BT, offsets=offsets, indices=indices, head_first=head_first) + + A_ab, A_qk, A_ak, A_qb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + scale=scale, + offsets=offsets, + indices=indices, + 
chunk_size=BT, + head_first=head_first + ) + w, u, A_ab_inv = fwd_prepare_wy_repr( + ag=ag, + A_ab=A_ab, + A_ak=A_ak, + v=v, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + del A_ab + h, v_new, _ = chunk_dplr_fwd_h( + kg=kg, + bg=bg, + v=v, + w=w, + u=u, + gk=gi, + initial_state=initial_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + del u + # ******* end of recomputation ******* + # A_ak, A_ab_inv, gi, ge torch.float32 + # A_qk, A_qb, qg, kg, ag, bg, v_new dtype=q.dtype, eg: bf16 + + dv_new_intra, dA_qk, dA_qb = chunk_dplr_bwd_dAu( + v=v, + v_new=v_new, + do=do, + A_qb=A_qb, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + + dh, dh0, dv_new = chunk_dplr_bwd_dhu( + qg=qg, + bg=bg, + w=w, + gk=gi, + h0=initial_state, + dht=dht, + do=do, + dv=dv_new_intra, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + + dv = chunk_dplr_bwd_dv( + A_qk=A_qk, + kg=kg, + do=do, + dh=dh, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + del A_qk + + dqg, dkg, dw, dbg, dgk_last = chunk_dplr_bwd_o( + k=kg, + b=bg, + v=v, + v_new=v_new, + do=do, + h=h, + dh=dh, + dv=dv_new, + w=w, + gk=gi, + offsets=offsets, + indices=indices, + chunk_size=BT, + scale=scale, + head_first=head_first, + ) + del v_new + + dA_ab, dA_ak, dv, dag = chunk_dplr_bwd_wy( + A_ab_inv=A_ab_inv, + A_ak=A_ak, + v=v, + ag=ag, + dw=dw, + du=dv_new, + dv0=dv, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + del A_ak + + dq, dk, da, db, dgk = chunk_dplr_bwd_dqk_intra( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + dAqk=dA_qk, + dAqb=dA_qb, + dAak=dA_ak, + dAab=dA_ab, + dgk_last=dgk_last, + dqg=dqg, + dkg=dkg, + dag=dag, + dbg=dbg, + chunk_size=BT, + scale=scale, + head_first=head_first, + offsets=offsets, + indices=indices + ) + + return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), dgk.to(gk), None, dh0, None, None, None + + +@torch.compiler.disable +def chunk_dplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + a (torch.Tensor): + activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + b (torch.Tensor): + betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + gk (torch.Tensor): + gk of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. decay term in log space! + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. 
+ cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + """ + assert q.dtype == k.dtype == v.dtype + # assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16." + # gk = gk.float() + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + scale = k.shape[-1] ** -0.5 if scale is None else scale + o, final_state = ChunkDPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + gk, + scale, + initial_state, + output_final_state, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..8feac35b2f01c54999a8185cc881184310b67ede --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py @@ -0,0 +1,446 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, gather +from fla.utils import check_shared_mem, is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BK', 'NC', 'BT', 'K'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_intra( + q, + k, + a, + b, + gi, + ge, + dAqk, + dAqb, + dAak, + dAab, + dq, + dk, + da, + db, + dqg, + dkg, + dag, + dbg, + dgk, + dgk_offset, + offsets, + indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_c // NC, i_c % NC + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + if i_t * BT + i_i * BC >= T: + return + + # offset calculation + ge += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + gi += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + q += i_bh * T * K if 
HEAD_FIRST else (bos*H + i_h) * K + a += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + b += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + k += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dq += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + da += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + db += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dqg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dag += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dkg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dbg += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dgk += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dgk_offset += i_bh * T * K if HEAD_FIRST else (bos*H + i_h) * K + dAqk += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT + dAqb += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT + dAak += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT + dAab += i_bh * T * BT if HEAD_FIRST else (bos*H + i_h) * BT + + stride_qk = K if HEAD_FIRST else H*K + stride_A = BT if HEAD_FIRST else H*BT + + p_ge = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_ge = tl.load(p_ge, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + b_da = tl.zeros([BC, BK], dtype=tl.float32) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + b_db = tl.zeros([BC, BK], dtype=tl.float32) + # intra chunk gradient calculation + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0)) + p_dAab = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0)) + p_dAqb = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0)) + p_dAak = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t*BT + i_i*BC, i_i*BC), (BC, BC), (1, 0)) + o_i = tl.arange(0, BC) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t*BT + i_i*BC, i_k*BK), (BC, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32) + b_b = tl.load(p_b, boundary_check=(0, 1)).to(tl.float32) + b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32) + b_a = tl.load(p_a, boundary_check=(0, 1)).to(tl.float32) + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)).to(tl.float32) + b_dAab = tl.load(p_dAab, boundary_check=(0, 1)).to(tl.float32) + b_dAqb = tl.load(p_dAqb, boundary_check=(0, 1)).to(tl.float32) + b_dAak = tl.load(p_dAak, boundary_check=(0, 1)).to(tl.float32) + + # inter chunk gradient calculation + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + if i_i > 0: + p_gn = gi + (i_t * BT + i_i * BC - 1) * stride_qk + o_k + p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + # [BK,] + for i_j in range(0, i_i): + p_kj = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_bj = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gkj = tl.make_block_ptr(gi, (T, 
K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqikj = tl.make_block_ptr(dAqk, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_dAaibj = tl.make_block_ptr(dAab, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_dAqibj = tl.make_block_ptr(dAqb, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_dAaikj = tl.make_block_ptr(dAak, (T, BT), (stride_A, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_kj = tl.load(p_kj, boundary_check=(0, 1)) + b_bj = tl.load(p_bj, boundary_check=(0, 1)) + b_gkj = tl.load(p_gkj, boundary_check=(0, 1)) + tmp = exp(b_gn[None, :] - b_gkj) + b_kjg = b_kj * tmp + b_bjg = b_bj * tmp + # [BC, BC] + b_dAqikj = tl.load(p_dAqikj, boundary_check=(0, 1)) + b_dAaibj = tl.load(p_dAaibj, boundary_check=(0, 1)) + b_dAqibj = tl.load(p_dAqibj, boundary_check=(0, 1)) + b_dAaikj = tl.load(p_dAaikj, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dAqikj, b_kjg) + b_dq += tl.dot(b_dAqibj, b_bjg) + # [BC, BC] + b_da += tl.dot(b_dAaibj, b_bjg) + b_da += tl.dot(b_dAaikj, b_kjg) + b_dq *= exp(b_gi - b_gn[None, :]) + b_da *= exp(b_ge - b_gn[None, :]) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + p_gn = gi + (min(i_t * BT + i_i * BC + BC, T) - 1)*stride_qk + o_k + p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + m_j = (i_t * BT + i_j * BC + tl.arange(0, BC)) < T + p_qj = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_aj = tl.make_block_ptr(a, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gij = tl.make_block_ptr(gi, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gej = tl.make_block_ptr(ge, (T, K), (stride_qk, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqjki = tl.make_block_ptr(dAqk, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + p_dAajbi = tl.make_block_ptr(dAab, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + p_dAqjbi = tl.make_block_ptr(dAqb, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + p_dAajki = tl.make_block_ptr(dAak, (BT, T), (1, stride_A), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + b_qj = tl.load(p_qj, boundary_check=(0, 1)) + b_aj = tl.load(p_aj, boundary_check=(0, 1)) + b_gij = tl.load(p_gij, boundary_check=(0, 1)) + b_gej = tl.load(p_gej, boundary_check=(0, 1)) + b_gij = tl.where(m_j[:, None] & m_k, b_gij, float('-inf')) + b_gej = tl.where(m_j[:, None] & m_k, b_gej, float('-inf')) + b_qjg = b_qj * exp(b_gij - b_gn[None, :]) + b_ajg = b_aj * exp(b_gej - b_gn[None, :]) + # [BC, BC] + b_dAqjki = tl.load(p_dAqjki, boundary_check=(0, 1)) + b_dAajbi = tl.load(p_dAajbi, boundary_check=(0, 1)) + b_dAqjbi = tl.load(p_dAqjbi, boundary_check=(0, 1)) + b_dAajki = tl.load(p_dAajki, boundary_check=(0, 1)) + b_dk += tl.dot(b_dAqjki, b_qjg) + b_dk += tl.dot(b_dAajki, b_ajg) + b_db += tl.dot(b_dAqjbi, b_qjg) + b_db += tl.dot(b_dAajbi, b_ajg) + tmp = exp(b_gn[None, :] - b_gi) + b_dk *= tmp + b_db *= tmp + + # intra chunk gradient calculation + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # trick to index the block + if GATHER_SUPPORTED: + row_idx = tl.full([1, BK], j, dtype=tl.int16) + col_idx = tl.full([BC, 1], j, dtype=tl.int16) + row_idx_bc = tl.full([1, BC], j, dtype=tl.int16) + # [1, BK] + b_kj = gather(b_k, row_idx, axis=0) 
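+ # NOTE: gather() pulls out the j-th row of a [BC, BK] tile without a data-dependent
+ # pointer; when tl.gather is unsupported, the else-branch below reproduces the same
+ # [1, BK] rows and [BC, 1] columns via a one-hot masked sum over tl.arange(0, BC) == j.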
+ b_bj = gather(b_b, row_idx, axis=0) + b_gij = gather(b_gi, row_idx, axis=0) + b_gej = gather(b_ge, row_idx, axis=0) + b_qj = gather(b_q, row_idx, axis=0) + b_aj = gather(b_a, row_idx, axis=0) + # [BC, 1] + b_dAqk_j = gather(b_dAqk, col_idx, axis=1) + b_dAab_j = gather(b_dAab, col_idx, axis=1) + b_dAqb_j = gather(b_dAqb, col_idx, axis=1) + b_dAak_j = gather(b_dAak, col_idx, axis=1) + # [1, BC] -> [BC, 1] + b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None] + b_dA_qk_j = tl.sum(gather(b_dAqk, row_idx_bc, axis=0), 0)[:, None] + b_dA_ab_j = tl.sum(gather(b_dAab, row_idx_bc, axis=0), 0)[:, None] + b_dA_qb_j = tl.sum(gather(b_dAqb, row_idx_bc, axis=0), 0)[:, None] + b_dA_ak_j = tl.sum(gather(b_dAak, row_idx_bc, axis=0), 0)[:, None] + else: + mask_idx = tl.arange(0, BC) == j + b_kj = tl.sum(tl.where(mask_idx[:, None], b_k, 0), 0)[None, :] + b_bj = tl.sum(tl.where(mask_idx[:, None], b_b, 0), 0)[None, :] + b_gij = tl.sum(tl.where(mask_idx[:, None], b_gi, 0), 0)[None, :] + b_gej = tl.sum(tl.where(mask_idx[:, None], b_ge, 0), 0)[None, :] + b_dAqk_j = tl.sum(tl.where(mask_idx[None, :], b_dAqk, 0), 1)[:, None] + b_dAab_j = tl.sum(tl.where(mask_idx[None, :], b_dAab, 0), 1)[:, None] + b_dAqb_j = tl.sum(tl.where(mask_idx[None, :], b_dAqb, 0), 1)[:, None] + b_dAak_j = tl.sum(tl.where(mask_idx[None, :], b_dAak, 0), 1)[:, None] + b_dA_qk_j = tl.sum(tl.where(mask_idx[:, None], b_dAqk, 0), 0)[:, None] + b_dA_ab_j = tl.sum(tl.where(mask_idx[:, None], b_dAab, 0), 0)[:, None] + b_dA_qb_j = tl.sum(tl.where(mask_idx[:, None], b_dAqb, 0), 0)[:, None] + b_dA_ak_j = tl.sum(tl.where(mask_idx[:, None], b_dAak, 0), 0)[:, None] + # [1, BK] b_qj, b_aj + b_qj = tl.sum(tl.where(mask_idx[:, None], b_q, 0), 0)[None, :] + b_aj = tl.sum(tl.where(mask_idx[:, None], b_a, 0), 0)[None, :] + # tl.static_print(b_kj) + m_e = o_i[:, None] > j + m_i = o_i[:, None] >= j + tmp1 = exp(b_gi - b_gij) + tmp2 = exp(b_ge - b_gij) + b_dq += tl.where(m_i, b_dAqk_j * b_kj * tmp1, 0.) + b_dq += tl.where(m_i, b_dAqb_j * b_bj * tmp1, 0.) + b_da += tl.where(m_e, b_dAab_j * b_bj * tmp2, 0.) + b_da += tl.where(m_e, b_dAak_j * b_kj * tmp2, 0.) + + m_i = o_i[:, None] <= j + m_e = o_i[:, None] < j + tmp1 = exp(b_gij - b_gi) + tmp2 = exp(b_gej - b_gi) + b_dk += tl.where(m_i, b_dA_qk_j * b_qj * tmp1, 0.) + b_dk += tl.where(m_e, b_dA_ak_j * b_aj * tmp2, 0.) + b_db += tl.where(m_i, b_dA_qb_j * b_qj * tmp1, 0.) + b_db += tl.where(m_e, b_dA_ab_j * b_aj * tmp2, 0.) 
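+ # The loop above accumulates the intra-sub-chunk gradients; in naive per-row form
+ # (a descriptive reference only, not executed):
+ #   dq[i] += sum_{j<=i} (dAqk[i,j]*k[j] + dAqb[i,j]*b[j]) * exp(gi[i] - gi[j])
+ #   da[i] += sum_{j<i}  (dAak[i,j]*k[j] + dAab[i,j]*b[j]) * exp(ge[i] - gi[j])
+ #   dk[i] += sum_{j>=i} dAqk[j,i]*q[j]*exp(gi[j] - gi[i]) + sum_{j>i} dAak[j,i]*a[j]*exp(ge[j] - gi[i])
+ #   db[i] += sum_{j>=i} dAqb[j,i]*q[j]*exp(gi[j] - gi[i]) + sum_{j>i} dAab[j,i]*a[j]*exp(ge[j] - gi[i])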
+ # post processing + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_da = tl.make_block_ptr(da, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dqg = tl.make_block_ptr(dqg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dkg = tl.make_block_ptr(dkg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dbg = tl.make_block_ptr(dbg, (T, K), (stride_qk, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gn = gi + (min(i_t * BT + BT, T) - 1)*stride_qk + o_k + p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_da += tl.load(p_dag, boundary_check=(0, 1)) * exp(b_ge) + b_dq += tl.load(p_dqg, boundary_check=(0, 1)) * exp(b_gi) * scale + tmp = exp(b_gn[None, :] - b_gi) + b_dk += tl.load(p_dkg, boundary_check=(0, 1)) * tmp + b_db += tl.load(p_dbg, boundary_check=(0, 1)) * tmp + tl.store(p_dq, (b_dq).to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_da, b_da.to(p_da.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1)) + b_dgk = b_dq * b_q + b_da * b_a - b_dk * b_k - b_db * b_b + b_dgk_offset = b_da * b_a + tl.store(p_dgk, b_dgk.to(p_dgk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dgk_offset, b_dgk_offset.to(p_dgk_offset.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + for BK in [32, 64] + ], + key=['BK', 'BT', 'K'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_dgk_kernel( + dgk, + dgk_offset, + dgk_last, + dgk_output, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + T = eos - bos + stride_qk = K if HEAD_FIRST else H * K + dgk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dgk_offset += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dgk_last += ((i_bh * NT + i_t) * K) if HEAD_FIRST else (i_tg * H + i_h) * K + dgk_output += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + p_dgk_last = dgk_last + tl.arange(0, BK) + i_k * BK + m_k = 
tl.arange(0, BK) + i_k * BK < K + b_dgk_last = tl.load(p_dgk_last, mask=m_k, other=0) + p_dgk_offset = tl.make_block_ptr(dgk_offset, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dgk = tl.make_block_ptr(dgk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_dgk = tl.load(p_dgk, boundary_check=(0, 1)) + b_dgk_offset = tl.load(p_dgk_offset, boundary_check=(0, 1)) + # m_inv_cumsum = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]).to(tl.float32) + # b_dgk_cumsum = tl.dot(m_inv_cumsum, b_dgk, allow_tf32=False) + b_dgk_cumsum = tl.cumsum(b_dgk, 0, reverse=True) + b_dgk_cumsum += b_dgk_last[None, :] + b_dgk_cumsum -= b_dgk_offset + p_dgk_output = tl.make_block_ptr(dgk_output, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dgk_output, b_dgk_cumsum.to(p_dgk_output.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dqk_intra( + q: torch.Tensor, + k: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + dAqk: torch.Tensor, + dAqb: torch.Tensor, + dAak: torch.Tensor, + dAab: torch.Tensor, + dqg: torch.Tensor, + dkg: torch.Tensor, + dag: torch.Tensor, + dbg: torch.Tensor, + dgk_last: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + scale: float = 1.0, + chunk_size: int = 64, +): + if head_first: + B, H, T, K = q.shape + else: + B, T, H, K = q.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BK = min(64, triton.next_power_of_2(K)) if check_shared_mem() else min(32, triton.next_power_of_2(K)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + da = torch.empty_like(a) + db = torch.empty_like(b) + dgk = torch.empty_like(gi, dtype=torch.float) + dgk_offset = torch.empty_like(gi, dtype=torch.float) + + grid = (NK, NT * NC, B * H) + chunk_dplr_bwd_kernel_intra[grid]( + q=q, + k=k, + a=a, + b=b, + gi=gi, + ge=ge, + dAqk=dAqk, + dAqb=dAqb, + dAak=dAak, + dAab=dAab, + dq=dq, + dk=dk, + dgk=dgk, + dgk_offset=dgk_offset, + dqg=dqg, + dkg=dkg, + dag=dag, + dbg=dbg, + da=da, + db=db, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + HEAD_FIRST=head_first, + GATHER_SUPPORTED=is_gather_supported + ) + + def grid2(meta): return (NT, triton.cdiv(K, meta['BK']), B * H) + dgk_output = torch.empty_like(dgk) + + chunk_dplr_bwd_dgk_kernel[grid2]( + dgk=dgk, + dgk_offset=dgk_offset, + dgk_last=dgk_last, + dgk_output=dgk_output, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + BT=BT, + HEAD_FIRST=head_first + ) + return dq, dk, da, db, dgk_output diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..08518c203594e0f63f1e88b849ab688922e94f34 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, gather +from fla.utils import is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, 
num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['BC', 'K'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_A_kernel_intra_sub_inter( + q, + k, + a, + b, + gi, # cumsum + ge, # before cumsum + Aqk, + Aqb, + Aab, + Aak, + offsets, + indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_i, i_j = i_c // NC, i_c % NC + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + if i_i <= i_j: + return + + b_Aqk = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqb = tl.zeros([BC, BC], dtype=tl.float32) + b_Aab = tl.zeros([BC, BC], dtype=tl.float32) + b_Aak = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK) + else: + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1)) + b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1)) + b_ag = b_a * exp(b_gq_e - b_gn[None, :]) + b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, 
boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32) + tmp = exp(b_gn[:, None] - b_gk) + b_kg = b_k * tmp + b_bg = b_b * tmp + # [BC, BC] using tf32 to improve precision here. + b_Aab += tl.dot(b_ag, b_bg) + b_Aak += tl.dot(b_ag, b_kg) + b_Aqk += tl.dot(b_qg, b_kg) + b_Aqb += tl.dot(b_qg, b_bg) + + if HEAD_FIRST: + p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + else: + p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BK', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_A_kernel_intra_sub_intra( + q, + k, + a, + b, + gi, + ge, + qg, + kg, + ag, + bg, + Aqk, + Aqb, + Aab, + Aak, + offsets, + indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_j = i_i + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + last_idx = min((i_t+1) * BT, T) - 1 + if HEAD_FIRST: + o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_gi = 
tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK) + b_g_last = tl.load(p_g_last, mask=m_k, other=0) + + p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + else: + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK) + b_g_last = tl.load(p_g_last, mask=m_k, other=0) + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = b_q * scale + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32) + b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32) + + # deal with decay term. 
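+ # gi/ge hold the within-chunk cumulative decay in log space (gi: inclusive cumsum,
+ # ge: cumsum before the current step). qg absorbs the decay from the chunk start,
+ # exp(gi); kg and bg absorb the remaining decay up to the chunk's last position,
+ # exp(g_last - gi), so cross-chunk products with the state need no extra rescaling;
+ # ag uses exp(ge).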
+ g_exp = exp(b_gi) + g_exp_inv = exp(-b_gi + b_g_last[None, :]) + b_qg = b_q * g_exp + b_kg = b_k * g_exp_inv + b_bg = b_b * g_exp_inv + b_ag = b_a * exp(b_ge) + tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + # tl.debug_barrier() + + b_q = b_q.to(b_k.dtype) + # inner attn + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # a trick to index the j-th row of b_k, b_g, b_b + if GATHER_SUPPORTED: + row_idx = tl.full([1, BK], j, dtype=tl.int16) + # [1, BK] + b_k_j = gather(b_k, row_idx, axis=0) + b_gk_j = gather(b_gi, row_idx, axis=0) + b_b_j = gather(b_b, row_idx, axis=0) + else: + mask = tl.arange(0, BC) == j + b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :] + b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :] + b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :] + mask = tl.arange(0, BC) == j + tmp = exp(b_gi - b_gk_j) + b_A_qk = tl.sum(b_q * b_k_j * tmp, 1) + b_A_qk = tl.where(o_i >= j, b_A_qk, 0.) + b_A_qb = tl.sum(b_q * b_b_j * tmp, 1) + b_A_qb = tl.where(o_i >= j, b_A_qb, 0.) + tmp2 = exp(b_ge - b_gk_j) + b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1) + b_A_ak = tl.where(o_i > j, b_A_ak, 0.) + b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1) + b_A_ab = tl.where(o_i > j, b_A_ab, 0.) + tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A) + + +def chunk_fwd_intra_dplr_fn( + q: torch.Tensor, + k: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gi: torch.Tensor, + ge: torch.Tensor, + scale: float, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, +): + if head_first: + B, H, T, K = k.shape + else: + B, T, H, K = k.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BC = min(16, BT) + NC = triton.cdiv(BT, BC) + + Aqk = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype) + Aqb = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype) + # involving matrix inverse and it'd be better to use float here. 
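+ # Aab/Aak later feed fwd_prepare_wy_repr, which obtains A_ab_inv by inverting a
+ # (lower-)triangular matrix built from Aab; keeping them in float32 limits the
+ # rounding error compounded by that inverse, while Aqk/Aqb only enter plain matmuls
+ # and can stay in the input dtype (e.g. bf16).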
+ Aab = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + Aak = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + grid = (NT, NC * NC, B * H) + + chunk_dplr_fwd_A_kernel_intra_sub_inter[grid]( + q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak, + offsets=offsets, indices=indices, + scale=scale, + T=T, H=H, K=K, BT=BT, BC=BC, NC=NC, + HEAD_FIRST=head_first + ) + grid = (NT, NC, B * H) + BK = triton.next_power_of_2(K) + qg = torch.empty_like(q) + kg = torch.empty_like(k, dtype=q.dtype) + ag = torch.empty_like(a, dtype=q.dtype) + bg = torch.empty_like(b, dtype=q.dtype) + chunk_dplr_fwd_A_kernel_intra_sub_intra[grid]( + q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak, + qg=qg, kg=kg, ag=ag, bg=bg, + offsets=offsets, indices=indices, + scale=scale, + T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC, + GATHER_SUPPORTED=is_gather_supported + ) + return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..ffed7c5c1c8a8a0420c59ffcf32fb49453178797 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_offsets +from fla.ops.utils.op import exp +from fla.utils import check_shared_mem, use_cuda_graph + + +@triton.heuristics({ + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV', "V"], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_dhu( + qg, + bg, + w, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + offsets, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)) + + mask_k = tl.arange(0, BK) < K + for i_t in range(NT - 1, -1, -1): + if HEAD_FIRST: + p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) 
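+ # Backward-in-time state recursion (descriptive sketch of the updates below):
+ #   per BC sub-chunk:  dv2 = dv + bg @ dH_t,   dH_tmp += qg^T @ do + w^T @ dv2
+ # and the carry passed to chunk t-1 is  exp(gk_last) * dH_t + dH_tmp,
+ # where dH_t is the state gradient already stored for chunk t above.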
+ b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + if HEAD_FIRST: + p_qg = tl.make_block_ptr(qg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + else: + p_qg = tl.make_block_ptr(qg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BT] + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + # [BT, BK] + b_bg = tl.load(p_bg, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dv2 = b_dv + tl.dot(b_bg, b_dh.to(b_bg.dtype)) + tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BV] + b_dh_tmp += tl.dot(b_qg, b_do.to(b_qg.dtype)) + b_dh_tmp += tl.dot(b_w, b_dv2.to(b_qg.dtype)) + last_idx = min((i_t + 1) * BT, T) - 1 + if HEAD_FIRST: + bg_last = tl.load(gk + (i_nh * T + last_idx) * K + tl.arange(0, BK), mask=mask_k) + else: + bg_last = tl.load(gk + ((bos + last_idx) * H + i_h) * K + tl.arange(0, BK), mask=mask_k) + b_dh *= exp(bg_last)[:, None] + b_dh += b_dh_tmp + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dhu( + qg: torch.Tensor, + bg: torch.Tensor, + w: torch.Tensor, + gk: torch.Tensor, + h0: torch.Tensor, + dht: Optional[torch.Tensor], + do: torch.Tensor, + dv: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *qg.shape, do.shape[-1] + else: + B, T, H, K, V = *qg.shape, do.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
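+ # Tile sizes below follow available shared memory: Hopper-class GPUs can afford
+ # BV=64 (BC=64 when K <= 128), Ampere gets BV=BC=32, and smaller-SRAM devices
+ # fall back to BV=BC=16; BC is additionally capped by BT further down.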
+ # H100 + if check_shared_mem('hopper', qg.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', qg.device.index): # A100 + BV = 32 + BC = 32 + else: # Etc: 4090 + BV = 16 + BC = 16 + + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + + BC = min(BT, BC) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + if head_first: + dh = qg.new_empty(B, H, NT, K, V) + else: + dh = qg.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.zeros_like(dv) + + grid = (NK, NV, N * H) + chunk_dplr_bwd_kernel_dhu[grid]( + qg=qg, + bg=bg, + w=w, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dh, dh0, dv2 diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..b382d5905af9547a0626585453e43b01bf1b706c --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_offsets +from fla.ops.utils.op import exp +from fla.utils import check_shared_mem, use_cuda_graph + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_h( + kg, + v, + w, + bg, + u, + v_new, + gk, + h, + h0, + ht, + offsets, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, 
V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + b_hc = tl.zeros([BK, BV], dtype=tl.float32) + # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden + for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)): + if HEAD_FIRST: + p_kg = tl.make_block_ptr(kg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_w = tl.make_block_ptr(w + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + else: + p_kg = tl.make_block_ptr(kg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_bg = tl.make_block_ptr(bg+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_w = tl.make_block_ptr(w+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BC] + b_kg = tl.load(p_kg, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_bg = tl.load(p_bg, boundary_check=(0, 1)) + b_v2 = tl.dot(b_w, b_h.to(b_w.dtype)) + tl.load(p_u, boundary_check=(0, 1)) + b_hc += tl.dot(b_kg, b_v) + b_hc += tl.dot(b_bg.to(b_hc.dtype), b_v2) + tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if HEAD_FIRST: + b_g_last = tl.load(gk + i_nh * T * K + last_idx * K + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32) + else: + b_g_last = tl.load(gk + (bos + last_idx) * H * K + i_h * K + + tl.arange(0, BK), mask=tl.arange(0, BK) < K).to(tl.float32) + b_h *= exp(b_g_last[:, None]) + b_h += b_hc + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_h( + kg: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + bg: torch.Tensor, + gk: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *kg.shape, u.shape[-1] + else: + B, T, H, K, V = *kg.shape, u.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + 
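+ # Per-chunk state recursion realized by the kernel (descriptive sketch):
+ #   v_new   = w @ H_t + u
+ #   H_next  = exp(gk_last) * H_t + kg^T @ v + bg^T @ v_new
+ # i.e. a chunkwise realization of the DPLR update S_t = S_{t-1}(D_t + a_t b_t^T) + v_t k_t^T,
+ # with the state kept transposed as [K, V] and D_t = diag(exp(gk_t)).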
BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256." + # H100 can have larger block size + + if check_shared_mem('hopper', kg.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', kg.device.index): # A100 + BV = 32 + BC = 32 + else: + BV = 16 + BC = 16 + + BC = min(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + if head_first: + h = kg.new_empty(B, H, NT, K, V) + else: + h = kg.new_empty(B, NT, H, K, V) + final_state = kg.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + v_new = torch.empty_like(u) + grid = (NK, NV, N * H) + chunk_dplr_fwd_kernel_h[grid]( + kg=kg, + v=v, + w=w, + bg=bg, + u=u, + v_new=v_new, + h=h, + gk=gk, + h0=initial_state, + ht=final_state, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + NT=NT, + HEAD_FIRST=head_first + ) + return h, v_new, final_state diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..981901295b1b79ad881d7bd8600582e6c421f28a --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py @@ -0,0 +1,138 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, use_cuda_graph + +BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BK_LIST + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_fwd_kernel_o( + qg, + v, + v_new, + A_qk, + A_qb, + h, + o, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_qg = tl.load(p_qg, boundary_check=(0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_o += tl.dot(b_qg, b_h) + + if HEAD_FIRST: + p_Aqk = tl.make_block_ptr(A_qk + i_bh * T*BT, 
(T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aqb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_Aqk = tl.make_block_ptr(A_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aqb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_Aqk = tl.load(p_Aqk, boundary_check=(0, 1)) + b_Aqb = tl.load(p_Aqb, boundary_check=(0, 1)) + b_Aqk = tl.where(m_s, b_Aqk, 0) + b_Aqb = tl.where(m_s, b_Aqb, 0) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_o = b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_v_new.dtype), b_v_new) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_fwd_o( + qg: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + A_qk: torch.Tensor, + A_qb: torch.Tensor, + h: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *qg.shape, v.shape[-1] + else: + B, T, H, K, V = *qg.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = torch.empty_like(v) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_dplr_fwd_kernel_o[grid]( + qg=qg, + v=v, + v_new=v_new, + A_qk=A_qk, + A_qb=A_qb, + h=h, + o=o, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return o diff --git a/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..17e7f3483a21de3f634ae70e58f8b7858810dda3 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py @@ -0,0 +1,292 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard, use_cuda_graph + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BV in [16, 32, 64] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=['BK'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_dplr_delta_rule_fwd_kernel( + q, + k, + v, + a, + b, + 
gk, + o, + h0, + ht, + offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_k = tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + if HEAD_FIRST: + p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_a = a + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_b = b + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + o_k + p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v + p_o = o + i_nh * T*V + ((T-1) * V if REVERSE else 0) + o_v + + else: + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_a = a + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_b = b + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + o_k + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + p_o = o + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[None, :] & mask_v[:, None] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + o_k[None, :] * V + o_v[:, None] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + tmp = tl.sum(b_h * b_a[None, :], axis=1) + b_h = exp(b_gk)[None, :] * b_h + (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None]) + b_o = tl.sum(b_h * b_q[None, :], axis=1) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_a += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_b += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + o_k[None, :] * V + o_v[:, None] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_dplr_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, 
v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + BK = triton.next_power_of_2(K) + + h0 = initial_state + if output_final_state: + ht = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + ht = None + o = torch.empty_like(v) + + def grid(meta): return (triton.cdiv(V, meta['BV']), N * H) + fused_recurrent_dplr_delta_rule_fwd_kernel[grid]( + q, + k, + v, + a, + b, + gk, + o, + h0, + ht, + offsets, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + REVERSE=reverse, + HEAD_FIRST=head_first + ) + return o, ht + + +class FusedRecurrentDPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = False + ): + o, ht = fused_recurrent_dplr_delta_rule_fwd( + q=q, + k=k, + v=v, + a=a, + b=b, + gk=gk, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + offsets=offsets, + head_first=head_first + ) + return o, ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + raise NotImplementedError( + "Backward pass for fused_recurrent_dplr_delta_rule is not implemented and will not be supported. " + "This kernel is only for inference. " + "For training, please use `chunk_dplr_delta_rule`." + ) + + +def fused_recurrent_dplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + gk: torch.Tensor, + scale: Optional[float] = 1.0, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner. + + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` + a (torch.Tensor): + as of shape `[B, H, T, K]` + b (torch.Tensor): + bs of shape `[B, H, T, K]` + gk (torch.Tensor): + gk of shape `[B, H, T, K]` + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If None, it will default to `1 / sqrt(K)`. Default: `1.0`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the state passing in reverse order. Default: `False`. + cu_seqlens (Optional[torch.Tensor]): + Cumulative sequence lengths of shape `[N + 1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = q.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + o, final_state = FusedRecurrentDPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + gk, + scale, + initial_state, + output_final_state, + reverse, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/generalized_delta_rule/dplr/naive.py b/fla/ops/generalized_delta_rule/dplr/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ac253673e5361a375286347253f7d4e6f7a2f3 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/naive.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + +# S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T +# q, k, alpha, beta [B, H, L, D_K] +# v [B, H, L, D_V] + + +def dplr_recurrence(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True): + orig_dtype = q.dtype + b, h, l, d_k = q.shape + q, k, v, beta, gk = map(lambda x: x.float(), [q, k, v, beta, gk]) + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + + if initial_state is not None: + S += initial_state + + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i] + _alpha = alpha[:, :, i].clone() + _beta = beta[:, :, i].clone() + _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None] + S = S.clone() * gk[:, :, i].exp()[..., None] + _kv + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + S = None if output_final_state is False else S + return o.to(orig_dtype), S + + +def dplr_chunkwise(q, k, v, alpha, beta, gk, initial_state=None, output_final_state=True, chunk_size=32): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v + assert l % chunk_size == 0 + + S = k.new_zeros(b, h, d_k, d_v).to(q) + if initial_state is not None: + S += initial_state + + # note that diagonal is masked. + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, alpha, beta, gk = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', + c=chunk_size).float(), [q, k, v, alpha, beta, gk]) + + gk_cumsum = gk.cumsum(-2) + + # v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v + A_ab = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_qk = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_ak = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + A_qb = torch.zeros(b, h, l // chunk_size, chunk_size, chunk_size).to(q.device) + + for i in range(chunk_size): + alpha_i = alpha[:, :, :, i, None] + q_i = q[:, :, :, i, None] + gk_i = gk_cumsum[:, :, :, i, None] + mask = (torch.arange(chunk_size) <= i).to(q.device) + attn_i = (gk_i - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp() + A_qk[:, :, :, i, :] = (q_i * k * attn_i).sum(-1).clone() + A_qb[:, :, :, i, :] = (q_i * beta * attn_i).sum(-1).clone() + mask = (torch.arange(chunk_size) < i).to(q.device) + # shift by one. 
+ attn_i = (gk_i - gk[:, :, :, i, None] - gk_cumsum).masked_fill(~mask.unsqueeze(-1), float('-inf')).exp() + A_ab[:, :, :, i, :] = (alpha_i * beta * attn_i).sum(-1).clone() + A_ak[:, :, :, i, :] = (alpha_i * k * attn_i).sum(-1).clone() + + A_ab = A_ab + for i in range(1, chunk_size): + A_ab[..., i, :i] = A_ab[..., i, :i].clone() + (A_ab[..., i, :, None].clone() * A_ab[..., :, :i].clone()).sum(-2) + + A_ab = A_ab + torch.eye(chunk_size, dtype=torch.float, device=q.device) + u = A_ab @ (A_ak @ v) + w = A_ab @ ((gk_cumsum-gk).exp() * alpha) + + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i] + v2_i = u_i + w_i @ S + + o_1 = A_qk[:, :, i] @ v_i + o_2 = A_qb[:, :, i] @ v2_i + o_3 = (q_i * gk_cumsum[:, :, i].exp()) @ S + o[:, :, i] = o_1 + o_2 + o_3 + decay = (gk_cumsum[:, :, i, -1, None] - gk_cumsum[:, :, i]).exp() + S = S*gk_cumsum[:, :, i, -1, :, None].exp() + (k_i * decay).transpose(-1, -2) @ v_i + \ + (beta_i * decay).transpose(-1, -2) @ v2_i + S = None if output_final_state is False else S + return rearrange(o, 'b h n c d -> b h (n c) d'), S diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..d9ff775184d4f1fa4472bb172da19fdd45553ed6 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, is_intel_alchemist, use_cuda_graph + +# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449 +triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {} + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config(triton_config, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def bwd_prepare_wy_repr_kernel( + A_ab_inv, + A_ak, + ag, + v, + dw, + du, + dv, + dv0, + dag, + dAak, + dAab, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_Aak_t = tl.make_block_ptr(A_ak + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_dAak = tl.make_block_ptr(dAak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dAab = tl.make_block_ptr(dAab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_Aak_t = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * 
BT), (BT, BT), (0, 1)) + p_Aab_inv_t = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_dAak = tl.make_block_ptr(dAak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dAab = tl.make_block_ptr(dAab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_ab_inv_t = tl.load(p_Aab_inv_t, boundary_check=(0, 1)) + b_A_ak_t = tl.load(p_Aak_t, boundary_check=(0, 1)) + b_A_ak_t = tl.where(tl.arange(0, BT)[:, None] < tl.arange(0, BT)[None, :], b_A_ak_t, 0) + b_A_ab_inv_t = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A_ab_inv_t, 0) + b_A_tmp_t = tl.dot(b_A_ak_t, b_A_ab_inv_t).to(v.dtype.element_ty) + b_dA_tmp = tl.zeros([BT, BT], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv0 = tl.make_block_ptr(dv0 + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv0 = tl.make_block_ptr(dv0 + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA_tmp += tl.dot(b_du.to(b_v.dtype), tl.trans(b_v)) + b_dv0 = tl.load(p_dv0, boundary_check=(0, 1)) + b_dv = b_dv0 + tl.dot(b_A_tmp_t, b_du) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + b_dA_tmp = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_tmp, 0) + b_dA_ak = tl.dot(b_A_ab_inv_t, b_dA_tmp) + b_dA_ak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ak, 0) + tl.store(p_dAak, b_dA_ak, boundary_check=(0, 1)) + b_dA_ab_inv = tl.dot(b_dA_tmp, b_A_ak_t) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dag = tl.make_block_ptr(dag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_ag = tl.load(p_ag, boundary_check=(0, 1)) + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + b_dA_ab_inv += tl.dot(b_dw, tl.trans(b_ag)) + b_dag = tl.dot(b_A_ab_inv_t.to(b_dw.dtype), b_dw) + tl.store(p_dag, b_dag.to(p_dag.dtype.element_ty), boundary_check=(0, 1)) + + # if we know dL/dA^(-1), for dL/dA, we can use the following formula: + # dL/dA = -(A^(-1))^T @ (dL/dA^(-1)) @ (A^(-1))^T + # in the fwd pass we use fwd substitution to calculate (I-lower(A_ab))^-1. + # denote A = I - lower(A_ab), B = A^-1 + # in the backward pass. 
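# (annotation, not part of the diff) derivation of the rule below: with B = A^(-1),
# a perturbation gives dB = -B dA B, so <dL/dB, dB> = <-B^T (dL/dB) B^T, dA>, hence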
+ # dL/dA = -(B)^T @ (dL/dB) @ B^T + # dL/dA_ab = lower(B^T @ dL/dB @ B^T) + b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_dA_ab_inv, 0) + b_dA_ab_inv = tl.dot(b_A_ab_inv_t, b_dA_ab_inv) + b_dA_ab_inv = tl.dot(b_dA_ab_inv, b_A_ab_inv_t) + b_dA_ab_inv = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA_ab_inv, 0) + tl.store(p_dAab, b_dA_ab_inv, boundary_check=(0, 1)) + + +def chunk_dplr_bwd_wy( + A_ab_inv: torch.Tensor, + A_ak: torch.Tensor, + v: torch.Tensor, + ag: torch.Tensor, + dw: torch.Tensor, + du: torch.Tensor, + dv0: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + A_ab_inv, A_ak, v, ag, dw, du = map(lambda x: x.contiguous(), [A_ab_inv, A_ak, v, ag, dw, du]) + if head_first: + B, H, T, K, V = *dw.shape, du.shape[-1] + else: + B, T, H, K, V = *dw.shape, du.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(V), 32) + + dA_ab = torch.empty_like(A_ab_inv, dtype=torch.float) + dA_ak = torch.empty_like(A_ak, dtype=torch.float) + dv = torch.empty_like(v) + dag = torch.empty_like(ag) + + bwd_prepare_wy_repr_kernel[(NT, B * H)]( + A_ab_inv=A_ab_inv, + A_ak=A_ak, + ag=ag, + v=v, + dw=dw, + du=du, + dv=dv, + dv0=dv0, + dag=dag, + dAak=dA_ak, + dAab=dA_ab, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dA_ab, dA_ak, dv, dag diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py new file mode 100644 index 0000000000000000000000000000000000000000..3ef5ac298d5218a6a1c10087a2bafc547f03acff --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import gather +from fla.utils import is_gather_supported, use_cuda_graph + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk32( + A_ab, + A_ab_inv, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, # placeholder, do not delete + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + if HEAD_FIRST: + p_Aab = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_Aab = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_Aab_inv = 
tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_A_ab = tl.load(p_Aab, boundary_check=(0, 1)) + b_A_ab = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A_ab, 0) + for i in range(1, BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A_ab, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A_ab, 0) * (tl.arange(0, BT) < i) + b_A_ab = tl.where(mask[:, None], b_a, b_A_ab) + b_A_ab += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + tl.store(p_Aab_inv, b_A_ab.to(p_Aab_inv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BC'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk64( + A_ab, + A_ab_inv, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + GATHER_SUPPORTED: tl.constexpr = is_gather_supported +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + + p_A1 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A_ab + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv1 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A_inv2 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A_inv3 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv4 = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + else: + p_A1 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A_ab + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv1 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A_inv2 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A_inv3 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A_inv4 = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + + b_A = tl.load(p_A1, boundary_check=(0, 1)) + b_A2 = tl.load(p_A2, boundary_check=(0, 1)) + b_A3 = tl.load(p_A3, boundary_check=(0, 1)) + b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + if GATHER_SUPPORTED: + row_idx = tl.full([1, BC], i, dtype=tl.int16) + # 
[1, BK] -> [BK] + b_a = tl.sum(gather(b_A, row_idx, axis=0), 0) + b_a2 = tl.sum(gather(b_A2, row_idx, axis=0), 0) + else: + mask = tl.arange(0, BC) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + mask = tl.arange(0, BC) == i + # b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + # b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A3 = tl.dot(tl.dot(b_A2, b_A3), b_A) + # tl.debug_barrier() + tl.store(p_A_inv1, b_A.to(p_A_inv1.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_A_inv2, b_A2.to(p_A_inv2.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_A_inv3, b_A3.to(p_A_inv3.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + # causal mask + tl.store(p_A_inv4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A_inv4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def fwd_wu_kernel( + u, + w, + ag, + v, + A_ab_inv, + A_ak, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_A_ab_inv = tl.make_block_ptr(A_ab_inv + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_A_ak = tl.make_block_ptr(A_ak + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A_ab_inv = tl.make_block_ptr(A_ab_inv + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_A_ak = tl.make_block_ptr(A_ak + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_Aab_inv = tl.load(p_A_ab_inv, boundary_check=(0, 1)) + b_Aak = tl.load(p_A_ak, boundary_check=(0, 1)) + o_s = tl.arange(0, BT) + b_Aab_inv = tl.where(o_s[:, None] >= o_s[None, :], b_Aab_inv, 0) + b_Aak = tl.where(o_s[:, None] > o_s[None, :], b_Aak, 0) + # let's use tf32 here + b_Aak = tl.dot(b_Aab_inv, b_Aak) + # (SY 01/04) should be bf16 or tf32? To verify. 
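# (annotation, not part of the diff) what fwd_wu materializes per chunk, as a plain
# PyTorch sketch over single [BT, K]/[BT, V] chunk tensors; the masks mirror the
# tl.where guards above, and the helper name is ours, not part of the library.
import torch

def fwd_wu_reference(ag, v, A_ak, A_ab_inv):
    BT = ag.shape[0]
    lower = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=ag.device))
    A_ab_inv = A_ab_inv.masked_fill(~lower, 0)          # keep lower triangle incl. diagonal
    A_ak = A_ak.masked_fill(~torch.tril(lower, -1), 0)  # keep strictly lower triangle
    w = A_ab_inv @ ag           # [BT, K]
    u = (A_ab_inv @ A_ak) @ v   # [BT, V]
    return w, u
# --- end annotation ---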
+ b_Aak = b_Aak.to(v.dtype.element_ty, fp_downcast_rounding="rtne") + b_Aab_inv = b_Aab_inv.to(ag.dtype.element_ty, fp_downcast_rounding="rtne") + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_ag = tl.make_block_ptr(ag + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh * T * K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_ag = tl.make_block_ptr(ag + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_ag = tl.load(p_ag, boundary_check=(0, 1)) + b_w = tl.dot(b_Aab_inv, b_ag) # both bf16 or fp16 + tl.store(p_w, b_w.to(p_w.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_bh * T * V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_u = tl.dot(b_Aak, b_v) # both bf16 or fp16 + tl.store(p_u, b_u.to(p_u.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +def fwd_prepare_wy_repr( + ag: torch.Tensor, + v: torch.Tensor, + A_ak: torch.Tensor, + A_ab: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K = ag.shape + else: + B, T, H, K = ag.shape + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BC = min(BT, 32) + fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32 + A_ab_inv = torch.empty_like(A_ab) + fwd_fn[(NT, B * H)]( + A_ab=A_ab, + A_ab_inv=A_ab_inv, + offsets=offsets, + indices=indices, + T=T, + H=H, + BT=BT, + BC=BC, + HEAD_FIRST=head_first + ) + w, u = fwd_wu( + ag=ag, + v=v, + A_ak=A_ak, + A_ab_inv=A_ab_inv, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + return w, u, A_ab_inv + + +def fwd_wu( + ag: torch.Tensor, + v: torch.Tensor, + A_ak: torch.Tensor, + A_ab_inv: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *ag.shape, v.shape[-1] + else: + B, T, H, K, V = *ag.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BK = min(triton.next_power_of_2(K), 64) + BV = min(triton.next_power_of_2(V), 64) + + u = torch.empty_like(v) + w = torch.empty_like(ag) + fwd_wu_kernel[(NT, B*H)]( + ag=ag, + v=v, + A_ak=A_ak, + A_ab_inv=A_ab_inv, + w=w, + u=u, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return w, u diff --git a/fla/ops/generalized_delta_rule/iplr/__init__.py b/fla/ops/generalized_delta_rule/iplr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e44d2a773b31f43fce68c5a9d1e67a3b33f42411 --- 
/dev/null +++ b/fla/ops/generalized_delta_rule/iplr/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_iplr_delta_rule +from .fused_recurrent import fused_recurrent_iplr_delta_rule + +__all__ = [ + 'chunk_iplr_delta_rule', + 'fused_recurrent_iplr_delta_rule' +] diff --git a/fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-311.pyc b/fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cfe2938e72c608a2e1ec51c665348a8aa7096e0 Binary files /dev/null and b/fla/ops/generalized_delta_rule/iplr/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-311.pyc b/fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad0c89793620e456634ec34bfec3d3e69b7ef7fb Binary files /dev/null and b/fla/ops/generalized_delta_rule/iplr/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd68b42dceeb1d1e272be239e44da7d894555b9c Binary files /dev/null and b/fla/ops/generalized_delta_rule/iplr/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-311.pyc b/fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa810c3c8239a91a6c7a1b34b1fb74e8dd9a1b3b Binary files /dev/null and b/fla/ops/generalized_delta_rule/iplr/__pycache__/wy_fast.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/iplr/chunk.py b/fla/ops/generalized_delta_rule/iplr/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..07f76533b10f022ba2dd1bc2af075c5a4f537760 --- /dev/null +++ b/fla/ops/generalized_delta_rule/iplr/chunk.py @@ -0,0 +1,528 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.chunk_delta_h import prepare_chunk_offsets +from fla.ops.generalized_delta_rule.iplr.wy_fast import fwd_prepare_wy_repr +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, use_cuda_graph + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [2, 4, 8, 16] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_generalized_iplr_delta_rule_fwd_kernel_h( + k, + v, + d, + b, + u, + v_new, + h, + h0, + ht, + offsets, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, 
eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + b_hc = tl.zeros([BK, BV], dtype=tl.float32) + # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden + for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_b = tl.make_block_ptr(b + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + else: + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_b = tl.make_block_ptr(b+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_u = tl.make_block_ptr(u+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_v2 = tl.dot(b_d, b_h.to(b_d.dtype)) + tl.load(p_u, boundary_check=(0, 1)) + b_hc += tl.dot(b_k, b_v) + b_hc += tl.dot(b_b, b_v2.to(b_k.dtype)) + tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + b_h += b_hc + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in [2, 4, 8] + for num_stages in [2, 3] + ], + key=['BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_generalized_iplr_delta_rule_fwd_kernel_o( + q, + k, + v, + u, + b, + h, + o, + offsets, + indices, + 
scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K) + k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K) + b += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K) + v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V) + u += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V) + o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V) + h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V) + stride_qk = K if HEAD_FIRST else H*K + stride_vo = V if HEAD_FIRST else H*V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_Aqk = tl.zeros([BT, BT], dtype=tl.float32) + b_Aqb = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_b = tl.make_block_ptr(b, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_Aqk += tl.dot(b_q, b_k) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_Aqb += tl.dot(b_q, b_b) + + o_i = tl.arange(0, BT) + m_A = o_i[:, None] >= o_i[None, :] + b_Aqk = tl.where(m_A, b_Aqk, 0) + b_Aqb = tl.where(m_A, b_Aqb, 0) + + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_u = tl.load(p_u, boundary_check=(0, 1)) + b_o = (b_o + tl.dot(b_Aqk.to(b_v.dtype), b_v) + tl.dot(b_Aqb.to(b_u.dtype), b_u)) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_generalized_iplr_delta_rule_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + b: torch.Tensor, + h: torch.Tensor, + scale: Optional[float] = None, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *q.shape, v.shape[-1] + else: + B, T, H, K, V = *q.shape, v.shape[-1] + if scale is None: + scale = k.shape[-1] ** -0.5 + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = torch.empty_like(v) + + def grid(meta): return ( + triton.cdiv(V, meta['BV']), + NT, + B * H 
+ ) + chunk_generalized_iplr_delta_rule_fwd_kernel_o[grid]( + q=q, + k=k, + v=v, + u=v_new, + b=b, + h=h, + o=o, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return o + + +def chunk_generalized_iplr_delta_rule_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + b: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, u.shape[-1] + else: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256." + # H100 can have larger block size + + if check_shared_mem('hopper', k.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + elif check_shared_mem('ampere', k.device.index): # A100 + BV = 32 + BC = 32 + else: + BV = 16 + BC = 16 + + BC = min(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + if head_first: + h = k.new_empty(B, H, NT, K, V) + else: + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) + grid = (NK, NV, N * H) + + chunk_generalized_iplr_delta_rule_fwd_kernel_h[grid]( + k=k, + v=v, + d=w, + b=b, + u=u, + v_new=v_new, + h=h, + h0=initial_state, + ht=final_state, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + NT=NT, + HEAD_FIRST=head_first + ) + return h, v_new, final_state + + +def chunk_generalized_iplr_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + w, u, _ = fwd_prepare_wy_repr( + a=a, + b=b, + k=k, + v=v, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + + h, v_new, final_state = chunk_generalized_iplr_delta_rule_fwd_h( + k=k, + v=v, + b=b, + w=w, + u=u, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + o = chunk_generalized_iplr_delta_rule_fwd_o( + q=q, + k=k, + v=v, + v_new=v_new, + b=b, + h=h, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + return o, final_state + + +class ChunkGeneralizedIPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float, + 
initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True + ): + chunk_size = 64 + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = None + if offsets is not None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + + o, final_state = chunk_generalized_iplr_delta_rule_fwd( + q=q, + k=k, + v=v, + a=a, + b=b, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, + do: torch.Tensor, + dht: torch.Tensor + ): + raise NotImplementedError( + "Backward pass for ChunkGeneralizedIPLRDeltaRuleFunction is not implemented yet. " + "Stay tuned!" + ) + + +@torch.compiler.disable +def chunk_iplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + a (torch.Tensor): + activations of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + b (torch.Tensor): + betas of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkDeltaRuleFunction does not support float32. Please use bfloat16." + + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + scale = k.shape[-1] ** -0.5 if scale is None else scale + o, final_state = ChunkGeneralizedIPLRDeltaRuleFunction.apply( + q, + k, + v, + a, + b, + scale, + initial_state, + output_final_state, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py b/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..8cea0c212b9b4155f5a28f1ae3ac12e865a03d59 --- /dev/null +++ b/fla/ops/generalized_delta_rule/iplr/fused_recurrent.py @@ -0,0 +1,451 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2025, Songlin Yang, Yu Zhang + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BV in [32, 64] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=["BK"], +) +@triton.jit +def fused_recurrent_fwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V]. + a, # a [B, H, L, K] + b, # b [B, H, L, K] + o, # output [B, H, L, V] + ha, # tmp variable [B, H, L, V] for storing intermediate results of (h * a[None, :]).sum(0) + h0, # initial hidden state [B, H, K, V] + ht, # final hidden state [B, H, K, V] + offsets, # varlen offsets + scale, # K ** -0.5 + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # indices + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + if HEAD_FIRST: + p_q = q + i_nh * T*K + tl.arange(0, BK) + p_k = k + i_nh * T*K + tl.arange(0, BK) + p_a = a + i_nh * T*K + tl.arange(0, BK) + p_b = b + i_nh * T*K + tl.arange(0, BK) + p_o = o + i_nh * T*V + i_v * BV + tl.arange(0, BV) + p_v = v + i_nh * T*V + i_v * BV + tl.arange(0, BV) + p_ha = ha + i_nh * T*V + i_v * BV + tl.arange(0, BV) + else: + p_q = q + (bos * H + i_h) * K + tl.arange(0, BK) + p_k = k + (bos * H + i_h) * K + tl.arange(0, BK) + p_a = a + (bos * H + i_h) * K + tl.arange(0, BK) + p_b = b + (bos * H + i_h) * K + tl.arange(0, BK) + p_ha = ha + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + p_v = v + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + p_o = o + (bos * H + i_h) * V + i_v * BV + tl.arange(0, BV) + + mask_k = tl.arange(0, BK) < K + mask_v = (i_v * BV + tl.arange(0, BV)) < V + mask_h = mask_k[None, :] & 
mask_v[:, None] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + # to store + tmp = tl.sum(b_h * b_a[None, :], axis=1) + b_h += (tmp[:, None] * b_b[None, :] + b_k[None, :] * b_v[:, None]) + _o = b_h * b_q[None, :] + _o = tl.sum(_o, axis=1) + tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_v) + tl.store(p_ha, tmp.to(p_ha.dtype.element_ty), mask=mask_v) + p_q += K if HEAD_FIRST else K*H + p_k += K if HEAD_FIRST else K*H + p_o += V if HEAD_FIRST else V*H + p_v += V if HEAD_FIRST else V*H + p_ha += V if HEAD_FIRST else V*H + p_a += K if HEAD_FIRST else K*H + p_b += K if HEAD_FIRST else K*H + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K * V + (tl.arange(0, BK)[None, :]) * V + ((i_v * BV + tl.arange(0, BV))[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_DHT': lambda args: args['dht'] is not None, + 'USE_DH0': lambda args: args['dh0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3] + ], + key=["BK", "BV"], +) +@triton.jit +def fused_recurrent_bwd_kernel( + # B: batch_size, H: n_heads, T: seq_len, D: b_dhead + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + a, # a [B, H, L, K] + b, # b [B, H, L, K] + ha, # ha [B, H, L, V] + dht, # gradient of final state [B, H, K, V] + dh0, # gradient of initial state [B, H, K, V] + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + da, # gradient of a [NV, B, H, L, K] + db, # gradient of b [NV, B, H, L, K] + dha, # gradient of ha [NK, B, H, L, V] + h0, # initial state [B, H, K, V] + scale, # K ** -0.5 + offsets, # offsets + B, # batch_size + H, # n_heads + T, # seq_len + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state h0 + USE_DH0: tl.constexpr, # whether to use dh0 + USE_DHT: tl.constexpr, # whether to use dht + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + dk += i_v * B * H * K * T + db += i_v * B * H * K * T + dq += i_v * B * H * K * T + da += i_v * B * H * K * T + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + mask_k = tl.arange(0, BK) < K + mask_v = (tl.arange(0, BV) + i_v * BV) < V + + q += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + k += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + v += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV) + ha += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV) + a += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + b += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + do += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV) + dq += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + dk += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + dv += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV) + da += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + db += (i_nh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + dha += (i_nh * T*V + i_v * BV) if HEAD_FIRST else ((bos * H + i_h) * V + i_v * BV) + + p_q = q + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H) + p_k = k + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H) + p_v = v + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H) + p_ha = ha + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H) + p_a = a + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H) + p_b = b + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H) + p_do = do + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H) + p_dk = dk + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H) + p_dv = dv + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H) + p_dha = dha + tl.arange(0, BV) + (T - 1) * V * (1 if HEAD_FIRST else H) + p_db = db + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H) + p_da = da + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H) + p_dq = dq + tl.arange(0, BK) + (T - 1) * K * (1 if HEAD_FIRST else H) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_DHT: + p_ht = dht + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :]) + b_dh += tl.load(p_ht, mask=mask_k[:, 
None] & mask_v[None, :], other=0).to(tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32) + + b_dh += b_q[:, None] * b_do[None, :] + d_k = tl.sum(b_dh * b_v[None, :], axis=1) + d_v = tl.sum(b_dh * b_k[:, None], axis=0) + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_k) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_v) + + b_dha = tl.sum(b_dh * b_b[:, None], axis=0) + tl.store(p_dha, b_dha.to(p_dha.dtype.element_ty), mask=mask_v) + b_db = tl.sum(b_dh * b_ha[None, :], axis=1) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), mask=mask_k) + + b_dh += b_dha[None, :] * b_a[:, None] + p_do -= V if HEAD_FIRST else V*H + p_q -= K if HEAD_FIRST else K*H + p_k -= K if HEAD_FIRST else K*H + p_v -= V if HEAD_FIRST else V*H + p_dk -= K if HEAD_FIRST else K*H + p_dv -= V if HEAD_FIRST else V*H + p_b -= K if HEAD_FIRST else K*H + p_db -= K if HEAD_FIRST else K*H + p_a -= K if HEAD_FIRST else K*H + p_dha -= V if HEAD_FIRST else V*H + p_ha -= V if HEAD_FIRST else V*H + + if USE_DH0: + p_dh0 = dh0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_k[:, None] & mask_v[None, :]) + + tl.debug_barrier() + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_k[:, None] & mask_v[None, :] + p_h0 = h0 + i_nh * K * V + (tl.arange(0, BK)[:, None]) * V + ((i_v * BV + tl.arange(0, BV))[None, :]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + p_k = k + tl.arange(0, BK) + p_v = v + tl.arange(0, BV) + p_ha = ha + tl.arange(0, BV) + p_do = do + tl.arange(0, BV) + p_dha = dha + tl.arange(0, BV) + p_da = da + tl.arange(0, BK) + p_dq = dq + tl.arange(0, BK) + p_b = b + tl.arange(0, BK) + + for i in range(0, T): + b_dha = tl.load(p_dha, mask=mask_v, other=0).to(tl.float32) + d_a = tl.sum(b_dha[None, :] * b_h, axis=1) + tl.store(p_da, d_a.to(p_da.dtype.element_ty), mask=mask_k) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b, mask=mask_k, other=0).to(tl.float32) + b_ha = tl.load(p_ha, mask=mask_v, other=0).to(tl.float32) + b_h += b_k[:, None] * b_v[None, :] + b_b[:, None] * b_ha[None, :] + _d_q = b_h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_k) + + p_k += K if HEAD_FIRST else K*H + p_do += V if HEAD_FIRST else V*H + p_v += V if HEAD_FIRST else V*H + p_da += K if HEAD_FIRST else K*H + p_dha += V if HEAD_FIRST else V*H + p_ha += V if HEAD_FIRST else V*H + p_dq += K if HEAD_FIRST else K*H + p_b += K if HEAD_FIRST else K*H + + +class FusedRecurrentIPLRDeltaRuleFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, q, k, v, a, b, scale=None, initial_state=None, output_final_state=False, offsets=None, head_first=False): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + + BK = triton.next_power_of_2(K) + if 
output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float32) + else: + final_state = None + + ha = torch.empty_like(v, dtype=torch.float32) + + def grid(meta): return ( + triton.cdiv(V, meta['BV']), + N * H + ) + o = torch.empty_like(v) + fused_recurrent_fwd_kernel[grid]( + q=q, + k=k, + v=v, + a=a, + b=b, + o=o, + ha=ha, + h0=initial_state, + ht=final_state, + scale=scale, + offsets=offsets, + H=H, + T=T, + K=K, + V=V, + BK=BK, + HEAD_FIRST=head_first + ) + ctx.save_for_backward(q, k, v, a, b, ha, initial_state) + ctx.scale = scale + ctx.head_first = head_first + ctx.offsets = offsets + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht): + q, k, v, a, b, ha, initial_state = ctx.saved_tensors + if ctx.head_first: + B, H, T, K, V = *q.shape, v.shape[-1] + else: + B, T, H, K, V = *q.shape, v.shape[-1] + + N = B if ctx.offsets is None else len(ctx.offsets) - 1 + scale = ctx.scale + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 64) + NV = triton.cdiv(V, BV) + + dq = q.new_empty(NV, *q.shape) + dk = k.new_empty(NV, *k.shape) + da = a.new_empty(NV, *a.shape) + db = b.new_empty(NV, *b.shape) + dv = torch.empty_like(v) + dha = torch.empty_like(ha) + grid = (NV, N * H) + + if initial_state is not None and initial_state.requires_grad: + dh0 = torch.empty_like(initial_state, dtype=torch.float32) + else: + dh0 = None + + fused_recurrent_bwd_kernel[grid]( + q=q, + k=k, + v=v, + a=a, + b=b, + ha=ha, + dht=dht, + dh0=dh0, + do=do, + dq=dq, + dk=dk, + dv=dv, + da=da, + db=db, + dha=dha, + h0=initial_state, + scale=scale, + offsets=ctx.offsets, + B=B, + H=H, + T=T, + K=K, + V=V, + BK=BK, + BV=BV, + HEAD_FIRST=ctx.head_first + ) + dq = dq.sum(0) + dk = dk.sum(0) + da = da.sum(0) + db = db.sum(0) + return dq.to(q), dk.to(k), dv.to(v), da.to(a), db.to(b), None, dh0, None, None, None + + +def fused_recurrent_iplr_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + offsets: torch.Tensor = None, + head_first: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + This function computes the recurrence S_t = S_t @ (I + a_t b_t^T) + v_t k_t^T in a recurrent manner. + + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` + a (torch.Tensor): + as of shape `[B, H, T, K]` + b (torch.Tensor): + bs of shape `[B, H, T, K]` + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + offsets (Optional[torch.Tensor]): + + """ + if offsets is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`." 
+ f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(offsets) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(offsets) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = q.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + o, final_state = FusedRecurrentIPLRDeltaRuleFunction.apply( + q, k, v, a, b, scale, initial_state, output_final_state, offsets, head_first) + return o, final_state diff --git a/fla/ops/generalized_delta_rule/iplr/naive.py b/fla/ops/generalized_delta_rule/iplr/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..9da977011e943f7432be09b144c115d7661911ac --- /dev/null +++ b/fla/ops/generalized_delta_rule/iplr/naive.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +# S_t = S_t @ (I + alpha_t beta_t^T) + v_t k_t^T +# q, k, alpha, beta [B, H, L, D_K] +# v [B, H, L, D_V] +def iplr_recurrence(q, k, v, alpha, beta, initial_state=None, output_final_state=True): + orig_dtype = q.dtype + b, h, l, d_k = q.shape + q, k, v, beta = map(lambda x: x.float(), [q, k, v, beta]) + d_v = v.shape[-1] + o = torch.zeros_like(v) + S = torch.zeros(b, h, d_k, d_v).to(v) + q = q * (d_k ** -0.5) + + if initial_state is not None: + S += initial_state + + for i in range(l): + _k = k[:, :, i] + _q = q[:, :, i] + _v = v[:, :, i] + _alpha = alpha[:, :, i] + _beta = beta[:, :, i] + _kv = _k[..., None] * _v[..., None, :] + (S.clone() * _alpha[..., None]).sum(-2, keepdim=True) * _beta[..., None] + S = S + _kv + o[:, :, i] = torch.einsum('bhd,bhdm->bhm', _q, S) + S = None if output_final_state is False else S + return o.to(orig_dtype), S + + +def iplr_chunkwise(q, k, v, alpha, beta, initial_state=None, output_final_state=True, chunk_size=32): + b, h, l, d_k = q.shape + d_v = v.shape[-1] + q = q * (d_k ** -0.5) + v = v + assert l % chunk_size == 0 + + S = k.new_zeros(b, h, d_k, d_v) + if initial_state is not None: + S += initial_state + + # note that diagonal is masked. 
+ mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0) + q, k, v, alpha, beta = map(lambda x: rearrange(x, 'b h (n c) d -> b h n c d', c=chunk_size), [q, k, v, alpha, beta]) + + v2 = (alpha @ k.transpose(-1, -2)).masked_fill_(mask, 0) @ v + attn = (alpha @ beta.transpose(-1, -2)).masked_fill(mask, 0) + for i in range(1, chunk_size): + attn[..., i, :i] = attn[..., i, :i] + (attn[..., i, :, None].clone() * attn[..., :, :i].clone()).sum(-2) + + attn = attn + torch.eye(chunk_size, dtype=torch.float, device=q.device) + u = attn @ v2 + w = attn @ alpha + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, l // chunk_size): + q_i, k_i, v_i, u_i, w_i, beta_i = q[:, :, i], k[:, :, i], v[:, :, i], u[:, :, i], w[:, :, i], beta[:, :, i] + o_1 = (q_i @ k_i.transpose(-1, -2)).masked_fill_(mask, 0) @ v_i + v2_i = u_i + w_i @ S + o_2 = (q_i @ beta_i.transpose(-1, -2)).masked_fill_(mask, 0) @ (v2_i) + o_3 = q_i @ S + o[:, :, i] = o_1 + o_2 + o_3 + S = S + k_i.transpose(-1, -2) @ v_i + beta_i.transpose(-1, -2) @ v2_i + S = None if output_final_state is False else S + return rearrange(o, 'b h n c d -> b h (n c) d'), S diff --git a/fla/ops/generalized_delta_rule/iplr/wy_fast.py b/fla/ops/generalized_delta_rule/iplr/wy_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdfa7091500873765a36c6ef86506a203f4be19 --- /dev/null +++ b/fla/ops/generalized_delta_rule/iplr/wy_fast.py @@ -0,0 +1,338 @@ + +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, is_nvidia_hopper + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk32( + a, + b, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, # dummy placeholder + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + else: + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_A += tl.dot(b_a, b_b) + + b_A = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_A, 0) + for i in range(1, BT): + mask = tl.arange(0, BT) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a = 
b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BT) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += tl.arange(0, BT)[:, None] == tl.arange(0, BT)[None, :] + + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16] + ], + key=['BK'] +) +@triton.jit(do_not_specialize=['T']) +def fwd_prepare_wy_repr_kernel_chunk64( + a, + b, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + b_A2 = tl.zeros([BC, BC], dtype=tl.float32) + b_A3 = tl.zeros([BC, BC], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_a1 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_a2 = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + p_b1 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BC), (0, 1)) + p_b2 = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1)) + else: + p_a1 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BC, BK), (1, 0)) + p_a2 = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + BC, i_k * BK), (BC, BK), (1, 0)) + p_b1 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT), (BK, BC), (0, 1)) + p_b2 = tl.make_block_ptr(b + (bos * H + i_h) * K, (K, T), (1, K*H), (i_k * BK, i_t * BT + BC), (BK, BC), (0, 1)) + b_a1 = tl.load(p_a1, boundary_check=(0, 1)) + b_a2 = tl.load(p_a2, boundary_check=(0, 1)) + b_b1 = tl.load(p_b1, boundary_check=(0, 1)) + b_b2 = tl.load(p_b2, boundary_check=(0, 1)) + b_A += tl.dot(b_a1, b_b1, allow_tf32=False) + b_A2 += tl.dot(b_a2, b_b2, allow_tf32=False) + b_A3 += tl.dot(b_a2, b_b1, allow_tf32=False) + + b_A = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A, 0) + b_A2 = tl.where(tl.arange(0, BC)[:, None] > tl.arange(0, BC)[None, :], b_A2, 0) + + for i in range(1, BC): + mask = tl.arange(0, BC) == i + b_a = tl.sum(tl.where(mask[:, None], b_A, 0), 0) + b_a2 = tl.sum(tl.where(mask[:, None], b_A2, 0), 0) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) * (tl.arange(0, BC) < i) + b_a2 = b_a2 + tl.sum(b_a2[:, None] * b_A2, 0) * (tl.arange(0, BC) < i) + b_A = tl.where(mask[:, None], b_a, b_A) + b_A2 = tl.where(mask[:, None], b_a2, b_A2) + + # blockwise computation of lower triangular matrix's inverse + # i.e., [A11, 0; A21, A22]^-1 = [A11^-1, 0; -A22^-1 A21 A11^-1, A22^-1] + b_A += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A2 += tl.arange(0, BC)[:, None] == tl.arange(0, BC)[None, :] + b_A3 = tl.dot(tl.dot(b_A2, b_A3, 
allow_tf32=False), b_A, allow_tf32=False) + + if HEAD_FIRST: + p_A1 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A4 = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + else: + p_A1 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BC, BC), (1, 0)) + p_A2 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, BC), (BC, BC), (1, 0)) + p_A3 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT + BC, 0), (BC, BC), (1, 0)) + p_A4 = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, BC), (BC, BC), (1, 0)) + tl.store(p_A1, b_A.to(p_A1.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A2, b_A2.to(p_A2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_A3, b_A3.to(p_A3.dtype.element_ty), boundary_check=(0, 1)) + # causal mask + tl.store(p_A4, tl.zeros([BC, BC], dtype=tl.float32).to(p_A4.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in NUM_WARPS + ], + key=['BT', 'BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def fwd_wu_kernel( + w, + u, + a, + k, + v, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_Aak = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_a = tl.load(p_a, boundary_check=(0, 1)) + b_w = tl.dot(b_A, b_a) + b_Aak += tl.dot(b_a, tl.trans(b_k)) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + b_Aak = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_Aak, 0) + b_Aak = b_Aak.to(k.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), 
(i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = tl.dot(b_Aak, b_v).to(v.dtype.element_ty) + b_u = tl.dot(b_A, b_v) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + +def fwd_prepare_wy_repr( + a: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K = a.shape + else: + B, T, H, K = a.shape + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BC = min(BT, 32) + BK = min(triton.next_power_of_2(K), 64) + + A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=a.device, dtype=a.dtype) + fwd_fn = fwd_prepare_wy_repr_kernel_chunk64 if BT == 64 else fwd_prepare_wy_repr_kernel_chunk32 + + fwd_fn[(NT, B * H)]( + a=a, + b=b, + A=A, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + BT=BT, + BK=BK, + BC=BC, + HEAD_FIRST=head_first + ) + w, u = fwd_wu( + a=a, + v=v, + k=k, + A=A, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return w, u, A + + +def fwd_wu( + a: torch.Tensor, + v: torch.Tensor, + k: torch.Tensor, + A: torch.Tensor, + offsets: Optional[torch.LongTensor], + indices: Optional[torch.LongTensor], + head_first: bool, + chunk_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *a.shape, v.shape[-1] + else: + B, T, H, K, V = *a.shape, v.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + + u = torch.empty_like(v) + w = torch.empty_like(a) + fwd_wu_kernel[(NT, B*H)]( + a=a, + v=v, + w=w, + u=u, + A=A, + k=k, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return w, u diff --git a/fla/ops/gla/__init__.py b/fla/ops/gla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..367c85442a26fe56516716622433f8b6f87afd2c --- /dev/null +++ b/fla/ops/gla/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_gla +from .fused_chunk import fused_chunk_gla +from .fused_recurrent import fused_recurrent_gla + +__all__ = [ + 'chunk_gla', + 'fused_chunk_gla', + 'fused_recurrent_gla' +] diff --git a/fla/ops/gla/__pycache__/__init__.cpython-311.pyc b/fla/ops/gla/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f410504e2ca37bf3fd0899605c86d149c3ece285 Binary files /dev/null and b/fla/ops/gla/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/gla/__pycache__/chunk.cpython-311.pyc b/fla/ops/gla/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6c60431beb4094456ca897648f6b5486edc99aa Binary files /dev/null and b/fla/ops/gla/__pycache__/chunk.cpython-311.pyc 
differ diff --git a/fla/ops/gla/__pycache__/fused_chunk.cpython-311.pyc b/fla/ops/gla/__pycache__/fused_chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0638c9c2f60d6728bb098bfb3e753762e89821c Binary files /dev/null and b/fla/ops/gla/__pycache__/fused_chunk.cpython-311.pyc differ diff --git a/fla/ops/gla/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/gla/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a4630ceceec57566830501f5d0b1e46205974cb Binary files /dev/null and b/fla/ops/gla/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/gla/chunk.py b/fla/ops/gla/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e40f492d0945df1ec2890c5da9f09cf4b6de9c --- /dev/null +++ b/fla/ops/gla/chunk.py @@ -0,0 +1,1486 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h +from fla.ops.common.utils import prepare_chunk_indices +from fla.ops.utils import chunk_local_cumsum +from fla.ops.utils.op import exp, safe_exp +from fla.utils import check_shared_mem, input_guard + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BC"] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_A_kernel_intra_sub_inter( + q, + k, + g, + A, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_i, i_j = i_c // NC, i_c % NC + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + if i_i <= i_j: + return + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gk = tl.make_block_ptr(g + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + else: + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), 
(0, 1)) + p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_qg = b_q * exp(b_g - b_gn[None, :]) * scale + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp(b_gn[:, None] - b_gk) + # [BC, BC] using tf32 to improve precision here. + b_A += tl.dot(b_qg, b_kg) + + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (bos*H + i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=["BK", "BT"] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_A_kernel_intra_sub_intra( + q, + k, + g, + A, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_j = i_i + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + if HEAD_FIRST: + o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + else: + o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0)) + p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A = tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i >= j, b_A * scale, 0.) 
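+        # column j of the diagonal BC x BC block: sum_k q_i[k] * k_j[k] * exp(g_i[k] - g_j[k]),
+        # kept (and scaled) only for the causal rows i >= j, then written out below.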
+ + tl.store(A + o_A + j, b_A, mask=m_A) + p_k += K if HEAD_FIRST else H*K + p_gk += K if HEAD_FIRST else H*K + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC', 'BK'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_A_kernel_intra_sub_intra_split( + q, + k, + g, + A, + offsets, + indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_tc // NC, i_tc % NC + i_j = i_i + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_i * BC >= T: + return + + o_i = tl.arange(0, BC) + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + + if HEAD_FIRST: + o_A = (i_k * B*H + i_bh) * T * BC + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BC + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + p_gk = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_j * BC) * K + o_k, BK), BK) + else: + o_A = (i_k * all + bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BC + i_h * BC + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = k + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + p_gk = g + (bos + i_t * BT + i_j * BC) * H*K + i_h * K + o_k + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + b_A = tl.zeros([BC], dtype=tl.float32) + b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32) + b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32) + b_A += tl.sum(b_q * b_k[None, :] * exp(b_g - b_gk[None, :]), 1) + b_A = tl.where(o_i >= j, b_A * scale, 0.) 
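+        # partial result over this K split only; the NK partial blocks are later summed
+        # into the full diagonal block by chunk_gla_fwd_A_kernel_intra_sub_intra_merge.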
+ tl.store(A + o_A + j, b_A, mask=m_A) + p_k += K if HEAD_FIRST else H*K + p_gk += K if HEAD_FIRST else H*K + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BC'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_A_kernel_intra_sub_intra_merge( + A, + A2, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + NK: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + all = T + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + all = B * T + + if i_t * BT + i_c * BC >= T: + return + + b_A = tl.zeros([BC, BC], dtype=tl.float32) + for i_k in range(0, NK): + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + (i_k*B*H+i_bh)*T*BC, (T, BC), (BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (i_k*all+bos)*H*BC+i_h*BC, (T, BC), (H*BC, 1), (i_t*BT + i_c*BC, 0), (BC, BC), (1, 0)) + b_A += tl.load(p_A, boundary_check=(0, 1)) + if HEAD_FIRST: + p_A2 = tl.make_block_ptr(A2 + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + else: + p_A2 = tl.make_block_ptr(A2 + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0)) + tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) + for BK in [32, 64] + for BV in [64, 128] + for num_warps in [2, 4, 8] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_fwd_kernel_o( + q, + v, + g, + h, + o, + A, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = 
tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + b_qg = (b_q * exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(m_s, b_A, 0.).to(b_v.dtype) + b_o += tl.dot(b_A, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BK', 'NC', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_bwd_kernel_intra( + q, + k, + g, + dA, + dq, + dk, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_t, i_i = i_c // NC, i_c % NC + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + if i_t * BT + i_i * BC >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + if HEAD_FIRST: + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + else: + p_g = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + if i_i > 0: + if HEAD_FIRST: + p_gn = g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k + p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) + else: + p_gn = g + (bos + i_t * BT + i_i * BC) * H*K + i_h*K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(0, i_i): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + else: + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 
1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA+(bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t*BT+i_i*BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = (b_k * exp(b_gn[None, :] - b_gk)) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kg) + b_dq *= exp(b_g - b_gn[None, :]) + + o_i = tl.arange(0, BC) + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + if HEAD_FIRST: + o_dA = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + p_kj = tl.max_contiguous(tl.multiple_of(k + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_gkj = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + else: + o_dA = bos*H*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_i * BC + p_kj = k + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gkj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dq = tl.make_block_ptr(dq + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. + b_dq += tl.where(m_i, b_dA[:, None] * b_kj[None, :] * exp(b_g - b_gkj[None, :]), 0.) 
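+        # step to row j + 1 of the diagonal block: advance the key/gate pointers.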
+ p_kj += K if HEAD_FIRST else H*K + p_gkj += K if HEAD_FIRST else H*K + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + if HEAD_FIRST: + p_gn = g + (i_bh * T + min(i_t * BT + i_i * BC + BC, T) - 1) * K + o_k + p_gn = tl.max_contiguous(tl.multiple_of(p_gn, BK), BK) + else: + p_gn = g + (bos + min(i_t * BT + i_i * BC + BC, T) - 1) * H*K + i_h * K + o_k + + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t*BT + i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t*BT + i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (BT, T), (1, BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + else: + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_gq = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + (bos*H+i_h)*BT, (BT, T), (1, H*BT), (i_i*BC, i_t*BT + i_j*BC), (BC, BC), (0, 1)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gq = tl.load(p_gq, boundary_check=(0, 1)) + b_qg = b_q * safe_exp(b_gq - b_gn[None, :]) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + # (SY 09/17) important to not use bf16 here to have a good precision. + b_dk += tl.dot(b_dA, b_qg) + b_dk *= exp(b_gn[None, :] - b_gk) + if HEAD_FIRST: + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + p_qj = tl.max_contiguous(tl.multiple_of(q + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_gqj = tl.max_contiguous(tl.multiple_of(g + (i_bh * T + i_t * BT + i_i * BC) * K + o_k, BK), BK) + p_dk = tl.make_block_ptr(dk + i_bh*T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + else: + o_dA = bos*H*BT + (i_t * BT + i_i * BC) * H*BT + i_h * BT + i_i * BC + tl.arange(0, BC) + p_qj = q + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_gqj = g + (bos + i_t * BT + i_i * BC) * H*K + i_h * K + o_k + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dA = tl.load(dA + o_dA + j * (1 if HEAD_FIRST else H) * BT) + # [BK,] + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_gqj = tl.load(p_gqj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_gqj[None, :] - b_gk), 0.) 
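+        # step to row j + 1 of the diagonal block: advance the query/gate pointers.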
+ p_qj += K if HEAD_FIRST else H*K + p_gqj += K if HEAD_FIRST else H*K + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + ], + key=['BV', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_bwd_kernel_dA( + v, + do, + dA, + offsets, + indices, + scale, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + else: + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dA += tl.dot(b_do, b_v) + if HEAD_FIRST: + p_dA = tl.make_block_ptr(dA + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_dA = tl.where(m_s, b_dA * scale, 0.) 
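+    # write back the scaled block dA = tril(do @ v^T) * scale; the strictly
+    # upper-triangular (non-causal) part has been zeroed above.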
+ tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_bwd_kernel_dv( + k, + g, + A, + do, + dh, + dv, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # (SY 09/17) important to disallow tf32 here to maintain a good precision. 
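+    # intra-chunk term: b_A holds the transposed block A^T masked so only entries with
+    # i >= j survive, so the dot below accumulates dv_j += sum_{i >= j} A[i, j] * do_i.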
+ b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False) + + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + min(i_t * BT + BT, T) * K - K + o_k, BK), BK) + p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(i_t * BT + BT, T) - 1)*H*K + i_h * K + o_k + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_gn = exp(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BT, BV] + # (SY 09/17) it is ok to have bf16 interchunk gradient contribution here + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_gla_bwd_kernel_inter( + q, + k, + v, + h, + g, + do, + dh, + dq, + dk, + dq2, + dk2, + dg, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + if HEAD_FIRST: + p_gk = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * T*K + (min(T, i_t * BT + BT)-1) * K + o_k, BK), BK) + else: + p_gk = tl.make_block_ptr(g + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(T, i_t * BT + BT)-1) * H*K + i_h * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK,], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * 
K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + else: + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + b_dgk *= exp(b_gn) + b_dq *= scale + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dq = b_dq * exp(b_gk) + b_dk = b_dk * exp(b_gn[None, :] - b_gk) + + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + # tl.debug_barrier() + b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] + # Buggy due to strange triton compiler issue. + # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.) 
+ # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :] + if HEAD_FIRST: + p_dq = tl.make_block_ptr(dq2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2 + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_dq = tl.make_block_ptr(dq2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_fwd_intra_gk( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K = k.shape + else: + B, T, H, K = k.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BC = min(16, BT) + NC = triton.cdiv(BT, BC) + + A = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + grid = (NT, NC * NC, B * H) + chunk_gla_fwd_A_kernel_intra_sub_inter[grid]( + q, + k, + g, + A, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + NC=NC, + HEAD_FIRST=head_first + ) + + grid = (NT, NC, B * H) + # load the entire [BC, K] blocks into SRAM at once + if K <= 256: + BK = triton.next_power_of_2(K) + chunk_gla_fwd_A_kernel_intra_sub_intra[grid]( + q, + k, + g, + A, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + HEAD_FIRST=head_first + ) + # split then merge + else: + BK = min(128, triton.next_power_of_2(K)) + NK = triton.cdiv(K, BK) + A_intra = q.new_empty(NK, B, *((H, T) if head_first else (T, H)), BC, dtype=torch.float) + + grid = (NK, NT * NC, B * H) + chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid]( + q, + k, + g, + A_intra, + offsets, + indices, + scale, + T=T, + B=B, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + HEAD_FIRST=head_first + ) + + grid = (NT, NC, B * H) + chunk_gla_fwd_A_kernel_intra_sub_intra_merge[grid]( + A_intra, + A, + offsets, + indices, + T=T, + B=B, + H=H, + BT=BT, + BC=BC, + NK=NK, + HEAD_FIRST=head_first + ) + return A + + +def chunk_gla_fwd_o_gk( + q: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K, V = *q.shape, v.shape[-1] + else: + B, T, H, K, V = *q.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = torch.empty_like(v) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_gla_fwd_kernel_o[grid]( + q, + v, + g, + h, + o, + A, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return o + + +def chunk_gla_bwd_dA( + v: torch.Tensor, + do: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = 
None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, V = v.shape + else: + B, T, H, V = v.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BV = min(64, triton.next_power_of_2(V)) + + dA = v.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float) + grid = (NT, B * H) + chunk_gla_bwd_kernel_dA[grid]( + v, + do, + dA, + offsets, + indices, + scale, + T=T, + H=H, + V=V, + BT=BT, + BV=BV, + HEAD_FIRST=head_first + ) + return dA + + +def chunk_gla_bwd_dv( + k: torch.Tensor, + g: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K, V = *k.shape, do.shape[-1] + else: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dv = torch.empty_like(do) + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + chunk_gla_bwd_kernel_dv[grid]( + k, + g, + A, + do, + dh, + dv, + offsets, + indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return dv + + +def chunk_gla_bwd_dqk_intra( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + dA: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K = q.shape + else: + B, T, H, K = q.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BC = min(16, BT) + BK = min(64, triton.next_power_of_2(K)) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + grid = (NK, NT * NC, B * H) + chunk_gla_bwd_kernel_intra[grid]( + q, + k, + g, + dA, + dq, + dk, + offsets, + indices, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + HEAD_FIRST=head_first + ) + return dq, dk + + +def chunk_gla_bwd_dqkg( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dg = torch.empty_like(g) + # work around triton compiler bugs. 
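+ # the inter kernel reads the intra-chunk gradients from `dq`/`dk` and writes the merged results + # into the fresh buffers `dq2`/`dk2`; keeping inputs and outputs in separate tensors (instead of + # updating `dq`/`dk` in place) is presumably what sidesteps the compiler issue noted above.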
+ dq2 = torch.empty_like(dq) + dk2 = torch.empty_like(dk) + def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) + chunk_gla_bwd_kernel_inter[grid]( + q, + k, + v, + h, + g, + do, + dh, + dq, + dk, + dq2, + dk2, + dg, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return dq2, dk2, dg + + +def chunk_gla_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + g_cumsum: Optional[torch.Tensor], + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + if g_cumsum is None: + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first) + + h, ht = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=g_cumsum, + gv=None, + h0=initial_state, + output_final_state=output_final_state, + states_in_fp32=False, + offsets=offsets, + head_first=head_first, + chunk_size=BT + ) + + # the intra A is kept in fp32 + # the computation has very marginal effect on the entire throughput + A = chunk_gla_fwd_intra_gk( + q=q, + k=k, + g=g_cumsum, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + o = chunk_gla_fwd_o_gk( + q=q, + v=v, + g=g_cumsum, + A=A, + h=h, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + return g_cumsum, A, h, ht, o + + +def chunk_gla_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + g_cumsum: Optional[torch.Tensor], + scale: float, + initial_state: torch.Tensor, + h: torch.Tensor, + A: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +): + T = q.shape[2] if head_first else q.shape[1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + if g_cumsum is None: + g_cumsum = chunk_local_cumsum(g, BT, offsets=offsets, indices=indices, head_first=head_first) + + if h is None: + h, _ = chunk_fwd_h( + k=k, + v=v, + g=None, + gk=g_cumsum, + gv=None, + h0=initial_state, + output_final_state=False, + offsets=offsets, + head_first=head_first, + chunk_size=BT, + states_in_fp32=True + ) + dh, dh0 = chunk_bwd_dh( + q=q, + k=k, + v=v, + g=None, + gk=g_cumsum, + gv=None, + do=do, + h0=initial_state, + dht=dht, + scale=scale, + offsets=offsets, + head_first=head_first, + chunk_size=BT, + states_in_fp32=True + ) + + dv = chunk_gla_bwd_dv( + k=k, + g=g_cumsum, + A=A, + do=do, + dh=dh, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + + # dq dk in fp32 + dA = chunk_gla_bwd_dA( + v=v, + do=do, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dq, dk = chunk_gla_bwd_dqk_intra( + q=q, + k=k, + g=g_cumsum, + dA=dA, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dq, dk, dg = chunk_gla_bwd_dqkg( + q=q, + k=k, + v=v, + h=h, + g=g_cumsum, + do=do, + dh=dh, + dq=dq, + dk=dk, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + return dq, dk, dv, dg, dh0 + + +class ChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def 
forward( + ctx, + q, + k, + v, + g, + scale, + initial_state, + output_final_state, + offsets, + head_first + ): + T = q.shape[2] if head_first else q.shape[1] + chunk_size = min(64, max(16, triton.next_power_of_2(T))) + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + g_cumsum, A, h, ht, o = chunk_gla_fwd( + q=q, + k=k, + v=v, + g=g, + g_cumsum=None, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + # recompute g_cumsum in bwd pass + if g.dtype != torch.float: + g_cumsum = None + else: + g = None + ctx.save_for_backward(q, k, v, g, g_cumsum, initial_state, A) + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.offsets = offsets + ctx.indices = indices + ctx.head_first = head_first + return o, ht + + @staticmethod + @input_guard + def backward(ctx, do, dht): + q, k, v, g, g_cumsum, initial_state, A = ctx.saved_tensors + chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first + dq, dk, dv, dg, dh0 = chunk_gla_bwd( + q=q, + k=k, + v=v, + g=g, + g_cumsum=g_cumsum, + scale=scale, + h=None, + A=A, + initial_state=initial_state, + do=do, + dht=dht, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return dq.to(q), dk.to(k), dv.to(v), dg, None, dh0, None, None, None + + +@torch.compiler.disable +def chunk_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: Optional[int] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + g (torch.Tensor): + Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
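+ + Note: + Roughly speaking, and following the step-by-step reference in `fla/ops/gla/naive.py`, the computation keeps a + per-head state `S` of shape `[K, V]` and applies the recurrence `S_t = diag(exp(g_t)) @ S_{t-1} + k_t^T v_t` + with outputs `o_t = (q_t * scale) @ S_t`; this function evaluates the same recurrence in a chunkwise-parallel form.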
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gla import chunk_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = chunk_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = q.shape[-1] ** -0.5 + o, final_state = ChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state, cu_seqlens, head_first) + return o, final_state diff --git a/fla/ops/gla/fused_chunk.py b/fla/ops/gla/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..318be21402e6e1d1324a5eed4bc318f100e4c59c --- /dev/null +++ b/fla/ops/gla/fused_chunk.py @@ -0,0 +1,631 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Tuple + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange +from packaging import version + +from fla.ops.utils import chunk_local_cumsum +from fla.ops.utils.op import exp, safe_exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit(do_not_specialize=['T']) +def prepare_qg_kg( + q, + k, + g, + qg, + kg, + scale, + T, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_q = q + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + p_g = g + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + p_qg = qg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + p_kg = kg + i_bh * T*K + i_c * BT * K + i_k * BK + tl.arange(0, BK) + + mask = (i_k * BK + tl.arange(0, BK)) < K + + last_decay = tl.load(g + i_bh * T*K + (i_c * BT + BT - 1) * K + i_k * BK + tl.arange(0, BK)) + + for _ in range(BT): + b_q = tl.load(p_q, mask=mask, other=0) + b_k = tl.load(p_k, mask=mask, other=0) + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_q *= exp(b_g) * scale + b_k *= exp(last_decay - 
b_g) + tl.store(p_kg, b_k.to(p_kg.dtype.element_ty), mask=mask) + tl.store(p_qg, b_q.to(p_qg.dtype.element_ty), mask=mask) + p_q += K + p_g += K + p_k += K + p_kg += K + p_qg += K + + +@triton.jit(do_not_specialize=['T']) +def bwd_decay_global_cumsum( + dq_inner, + dq_inter, + dk_inner, + dk_inter, + q, + k, + g, + dg, + T, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_q = q + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_k = k + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_g = g + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dg = dg + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dq_inner = dq_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dk_inner = dk_inner + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dq_inter = dq_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + p_dk_inter = dk_inter + i_bh * T*K + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * K + cum_grad_dg = tl.zeros([BK], dtype=tl.float32) + mask = (i_k * BK + tl.arange(0, BK)) < K + last_g = tl.zeros([BK], dtype=tl.float32) + for j in range(BT-1, -1, -1): + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + if j == (BT-1): + last_g = b_g + b_dq1 = tl.load(p_dq_inner, mask=mask, other=0) + b_dq2 = tl.load(p_dq_inter, mask=mask, other=0) + b_dq2 *= exp(b_g) + b_dq = b_dq1 + b_dq2 + tl.store(p_dq_inter, b_dq, mask=mask) + b_dk1 = tl.load(p_dk_inner, mask=mask, other=0) + b_dk2 = tl.load(p_dk_inter, mask=mask, other=0) + b_dk2 *= safe_exp(last_g - b_g) + b_dk = b_dk1 + b_dk2 + tl.store(p_dk_inter, b_dk, mask=mask) + b_q = tl.load(p_q, mask=mask, other=0) + b_k = tl.load(p_k, mask=mask, other=0) + b_dg = b_dq * b_q - b_dk * b_k + cum_grad_dg += b_dg + tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask) + p_g -= K + p_k -= K + p_q -= K + p_dq_inner -= K + p_dk_inner -= K + p_dq_inter -= K + p_dk_inter -= K + p_dg -= K + + +@triton.jit(do_not_specialize=['T']) +def fused_chunk_gla_fwd_kernel( + q, + k, + v, + g, + o, + h0, + ht, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + i_bh * T*K + (BT - 1) * K + i_k * BK + tl.arange(0, BK) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k * B * H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_gn = tl.load(p_gn, mask=mask, 
other=0).to(tl.float32) + if CHECK and i == 0: + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + else: + b_o = tl.dot(b_q.to(b_v.dtype), b_h.to(b_v.dtype), allow_tf32=False) + b_h = b_h * exp(b_gn)[:, None] + tl.dot(b_k.to(b_v.dtype), b_v, allow_tf32=False) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_gn += BT * K + + if STORE_FINAL_STATE: + p_final = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_final, b_h.to(p_final.dtype.element_ty), boundary_check=(0, 1)) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit(do_not_specialize=['T']) +def fused_chunk_gla_bwd_kernel( + q, k, v, g, + do, + dq, + dk, + dv, + h0, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + i_bh * T*K + ((i+1) * BT - 1) * K + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh+i_v*B*H)*T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, K] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gn = tl.load(p_gn, mask=mask, other=0).to(tl.float32) + + # [V, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [V, K] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h * exp(b_gn)[None, :] + tl.dot(b_v, b_k.to(b_v.dtype), allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + + # cum = tl.zeros([BK], dtype=tl.float32) + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + i_bh * T*K + (T - (i-1) * BT - 1) * K + i_k * BK + tl.arange(0, BK) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh 
+ i_v * B * H) * T*K, (T, K), + (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + i_k * B * H) * T*V, (T, V), + (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + # [K, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BT, K] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, V] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_db = tl.load(p_gn, mask=mask, other=0).to(tl.float32) + + # inter-chunk + # [K, V] + if CHECK and i == 1: + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) + b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + else: + b_dk = tl.trans(tl.dot(b_dh.to(b_v.dtype), tl.trans(b_v), allow_tf32=False)) + b_dv = tl.dot((b_k).to(b_v.dtype), b_dh.to(b_v.dtype), allow_tf32=False) + b_dh = b_dh * exp(b_db)[:, None] + tl.dot(b_q.to(b_do.dtype), b_do, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def fwd_inner_chunk( + q, k, g, A, + scale, # K ** -0.5 + B: tl.constexpr, # B + H: tl.constexpr, # H + T, # T + K: tl.constexpr, # K + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. chunk size + BK: tl.constexpr # BLOCK SIZE along the K dimension +): + + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + o_i = tl.arange(0, BT) + + p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_A = A + (i_bh + (i_k * B * H)) * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) + + for i in range(BT): + b_q = tl.load(p_q, mask=mask, other=0) * scale + b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) + s = b_q[None, :] * b_k * safe_exp(b_gq[None, :] - b_g) + score = tl.sum(s, axis=1) + score = tl.where(o_i <= i, score, 0) + tl.store(p_A, score.to(p_A.dtype.element_ty)) + p_q += K + p_gq += K + p_A += BT + + +@triton.jit +def bwd_inner_chunk( + q, + k, + g, + dA, + dq, + dk, + T, # T + K: tl.constexpr, # K + # clamp_min, # minimum log value of the gate for numerical stability. default: -5 + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_g = tl.make_block_ptr(g + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + + mask = (i_k * BK + tl.arange(0, BK)) < K + o_i = tl.arange(0, BT) + + p_q = q + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_dq = dq + (i_bh) * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_gq = g + i_bh * T*K + i_k * BK + i_t * BT * K + tl.arange(0, BK) + p_dA = dA + i_bh * (tl.cdiv(T, BT) * BT * BT) + i_t * BT * BT + tl.arange(0, BT) + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + + for i in range(BT): + b_q = tl.load(p_q, mask=mask, other=0) + b_gq = tl.load(p_gq, mask=mask, other=0).to(tl.float32) + score = safe_exp(b_gq[None, :] - b_g) + score = tl.where(o_i[:, None] <= i, score, 0) + b_dA = tl.load(p_dA) + b_dA = tl.where(o_i <= i, b_dA, 0) + b_dk += (b_dA[:, None] * score * b_q[None, :]) + b_dq = tl.sum(b_dA[:, None] * score * b_k, axis=0) + tl.store(p_dq, b_dq, mask=mask) + p_q += K + p_dq += K + p_gq += K + p_dA += BT + + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dk, b_dk.to(dk.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkGLAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, g, scale, initial_state, output_final_state): + ctx.g_dtype = g.dtype + ctx.scale = scale + B, H, T, K, V = *k.shape, v.shape[-1] + BT = 16 # chunk_size + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 2 + + g_org = g + # cumulative decay should be in float32, otherwise the error will be accumulated and amplified. + g = chunk_local_cumsum(g_org, chunk_size=BT) + o = q.new_empty(NK, B, H, T, V) + q_g = torch.empty_like(q) + k_g = torch.empty_like(k) + + grid = (NK, triton.cdiv(T, BT), B * H) + prepare_qg_kg[grid]( + q, + k, + g, + q_g, + k_g, + scale, + T=T, + K=K, + BT=BT, + BK=BK, + num_warps=1 + ) + + if output_final_state: + final_state = q.new_empty(B, H, K, V, dtype=torch.float, requires_grad=False) + else: + final_state = None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've added some initial condition checks to resolve this, at the cost of some speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)."
+ ) + CHECK = True + + grid = (NV, NK, B * H) + fused_chunk_gla_fwd_kernel[grid]( + q_g, k_g, v, g, o, initial_state, final_state, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + + o = o.sum(0) + + # intra-chunk + chunk_size = 16 + num_chunk = T // chunk_size + v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=num_chunk) + BK = min(K, 64) + NK = triton.cdiv(K, BK) + A = q.new_empty(NK, B, H, triton.cdiv(T, BT), BT, BT) + grid = (NK, triton.cdiv(T, BT), B * H) + fwd_inner_chunk[grid]( + q, k, g, A, + scale, + B=B, + H=H, + T=T, + K=K, + BT=BT, + BK=BK, + num_stages=3, + num_warps=4 + ) + A = A.sum(0) + o2 = A @ v2 + o2 = rearrange(o2, 'b h n c d -> b h (n c) d') + # combine inner and inter + o.add_(o2) + ctx.save_for_backward(q, k, v, g_org, A, initial_state) + ctx.CHECK = CHECK + return o.to(v), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, g_org, A, initial_state = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + # recomputation + # inter-chunk + BT = 16 # chunk_size + g = chunk_local_cumsum(g_org, chunk_size=BT) + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + q_g = torch.empty_like(q) + k_g = torch.empty_like(k) + grid = (NK, triton.cdiv(T, BT), B * H) + prepare_qg_kg[grid]( + q, + k, + g, + q_g, + k_g, + scale, + T=T, + K=K, + BT=BT, + BK=BK, + num_warps=1 + ) + + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 2 + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + + grid = (NV, NK, B * H) + + fused_chunk_gla_bwd_kernel[grid]( + q_g, + k_g, + v, + g, + do, + dq, + dk, + dv, + initial_state, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages, + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + + # intra chunk + NT = T // BT + v2 = rearrange(v, 'b h (n c) d -> b h n c d', n=NT) + do2 = rearrange(do, 'b h (n c) d -> b h n c d', n=NT) + dA2 = (do2 @ v2.transpose(-2, -1)) * scale + dv2 = A.transpose(-1, -2) @ do2 + dv2 = rearrange(dv2, 'b h n c d -> b h (n c) d', n=NT) + + BK = min(triton.next_power_of_2(K), 16) + NK = triton.cdiv(K, BK) + dk2 = torch.empty_like(k) + dq2 = torch.empty_like(q) + + grid = (NK, NT, B * H) + bwd_inner_chunk[grid]( + q, k, g, + dA2, + dq2, + dk2, + T=T, + K=K, + BT=BT, + BK=BK, + num_warps=1, + num_stages=3 + ) + + BK = min(triton.next_power_of_2(K), 32) + NK = triton.cdiv(K, BK) + dg = torch.empty_like(g, dtype=torch.float32) + grid = (NK, triton.cdiv(T, BT), B * H) + bwd_decay_global_cumsum[grid]( + dq2, + dq, + dk2, + dk, + q, + k, + g, + dg, + T=T, + K=K, + BT=BT, + BK=BK, + num_warps=1, + num_stages=1 + ) + dg = rearrange(dg, 'b h (n c) d -> b h n c d', c=BT) + + def rev_cumsum_exclusive(x): + cumsum_x = x.cumsum(-2) + rev_cumsum_x = cumsum_x[..., -1, None, :] - cumsum_x + return rev_cumsum_x + + rev_cumsum_dg = rev_cumsum_exclusive(dg[..., 0, :]) + dg.add_(rev_cumsum_dg.unsqueeze(-2)) + dv.add_(dv2) + dg = rearrange(dg, 'b h n c d -> b h (n c) d') + + return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.g_dtype), None, None, None + + +def ceildiv(a, b): + return 
-(a // -b) + + +def pad(x, chunk_size=16): + T = x.shape[-2] + padded_seq_len = ceildiv(T, chunk_size) * chunk_size + if x.shape[-2] % chunk_size != 0: + x = F.pad(x, (0, 0, 0, padded_seq_len - T)) + return x + + +def fused_chunk_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: int = -1, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale == -1: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v, g = map(lambda x: x.transpose(1, 2), (q, k, v, g)) + seq_len = q.shape[-2] + q, k, v, g = map(lambda x: pad(x), [q, k, v, g]) + o, final_state = FusedChunkGLAFunction.apply(q, k, v, g, scale, initial_state, output_final_state) + o = o[..., :seq_len, :].contiguous() + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla/ops/gla/fused_recurrent.py b/fla/ops/gla/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..d211541d789809ee89a688b380626026b1dbed88 --- /dev/null +++ b/fla/ops/gla/fused_recurrent.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.common.fused_recurrent import fused_recurrent + + +def fused_recurrent_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + gk (torch.Tensor): + Forget gates of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` applied to keys. + gv (torch.Tensor): + Forget gates of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` applied to values. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the state passing in reverse order. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gla import fused_recurrent_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = fused_recurrent_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = fused_recurrent( + q=q, + k=k, + v=v, + g=None, + gk=gk, + gv=gv, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + head_first=head_first + ) + return o, final_state diff --git a/fla/ops/gla/naive.py b/fla/ops/gla/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..507a7395c0c28b0a9c54008e1735098cd3fbdc85 --- /dev/null +++ b/fla/ops/gla/naive.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def ceildiv(a, b): + return -(a // -b) + + +def naive_recurrent_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + gk: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False +): + dtype = q.dtype + q, k, v, gk = map(lambda x: x.float(), (q, k, v, gk)) + B, H, T, K, V = *q.shape, v.shape[-1] + o = torch.zeros_like(v) + scale = K ** -0.5 + + h = q.new_zeros(B, H, K, V, dtype=torch.float32) + if initial_state is not None: + h += initial_state.float() + + for i in range(T): + q_i = q[:, :, i] * scale + k_i = k[:, :, i] + v_i = v[:, :, i] + gk_i = gk[:, :, i].exp() + kv_i = k_i[..., None] * v_i[..., None, :] + h = h * gk_i[..., None] + kv_i + o[:, :, i] = (q_i[..., None] * h).sum(-2) + + if not output_final_state: + h = None + return o.to(dtype), h diff --git a/fla/ops/hgrn/__init__.py b/fla/ops/hgrn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f2012c3c15f125271df225ce755ed3b2dbe01a83 --- /dev/null +++ b/fla/ops/hgrn/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_hgrn +from .fused_recurrent import 
fused_recurrent_hgrn + +__all__ = [ + 'chunk_hgrn', + 'fused_recurrent_hgrn' +] diff --git a/fla/ops/hgrn/__pycache__/__init__.cpython-311.pyc b/fla/ops/hgrn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc7bac33fed7a21d7172f6f71b3ac406a8d7deca Binary files /dev/null and b/fla/ops/hgrn/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/hgrn/__pycache__/chunk.cpython-311.pyc b/fla/ops/hgrn/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..429c84df46d870c8a73ecf6dcc72506e24343141 Binary files /dev/null and b/fla/ops/hgrn/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/hgrn/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/hgrn/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0c4527173f085fae442835b24818738d37adb45 Binary files /dev/null and b/fla/ops/hgrn/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/hgrn/chunk.py b/fla/ops/hgrn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..6847622ebfb071230720b7ae6669f5412a42470b --- /dev/null +++ b/fla/ops/hgrn/chunk.py @@ -0,0 +1,282 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# this function implements the chunkwise form of HGRN, inspired by +# [Volodymyr Kyrylov in his blog post](https://proger.github.io/posts/scan/chunk.html) +# also refer to the `accelerated-scan` lib: https://github.com/proger/accelerated-scan + +# from tests on H800, with B, D = 16, 128, we see that the chunk can be greatly faster than the recurrent: +# +# Performance: +# seq_len chunk recurrent chunk_bwd recurrent_bwd +# 0 128.0 0.039360 0.061056 0.312160 0.205008 +# 1 256.0 0.045824 0.123712 0.308784 0.297696 +# 2 512.0 0.058688 0.241952 0.310720 0.626528 +# 3 1024.0 0.088288 0.476992 0.313184 1.333152 +# 4 2048.0 0.169472 0.943264 0.452464 2.724864 +# 5 4096.0 0.329920 1.886144 0.881600 5.551520 +# 6 8192.0 0.647872 3.755040 1.740496 11.117184 +# 7 16384.0 1.272064 7.520576 3.446608 22.362528 + +from typing import Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.autotune( + configs=[ + triton.Config({'BD': 32}, num_warps=1), + triton.Config({'BD': 32}, num_warps=2), + triton.Config({'BD': 32}, num_warps=4), + triton.Config({'BD': 32}, num_warps=8), + triton.Config({'BD': 64}, num_warps=1), + triton.Config({'BD': 64}, num_warps=2), + triton.Config({'BD': 64}, num_warps=4), + triton.Config({'BD': 64}, num_warps=8), + triton.Config({'BD': 128}, num_warps=1), + triton.Config({'BD': 128}, num_warps=2), + triton.Config({'BD': 128}, num_warps=4), + triton.Config({'BD': 128}, num_warps=8), + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_fwd_kernel_h( + x, + g, + gc, + o, + h0, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr +): + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + i_b * T * D + i_t * BT * D + o_d + p_g = g + i_b * T * D + i_t * BT * D + o_d + p_gc = gc + i_b * T * D + i_t * BT * D + o_d + p_o = o + i_b * T * D + i_t * BT * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + b_gc = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + if i_t == 0: + b_h += tl.load(h0 + i_b * D + o_d, mask=mask, 
other=0).to(tl.float32) + for i in range(0, BT): + mask_t = mask & ((i_t * BT + i) < T) + b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32) + b_h = exp(b_g) * b_h + b_x + b_gc = b_gc + b_g + tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t) + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t) + + p_x += D + p_g += D + p_gc += D + p_o += D + + +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_fwd_kernel_o( + gc, + o, + s_b, + s_t, + s_d, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_b = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(1, tl.cdiv(T, BT)): + p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + b_h0 = tl.load(o + i_b * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32) + # [BT, BD] + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_o = b_o + exp(b_gc) * b_h0[None, :] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_bwd_kernel_h( + g, + gc, + dx, + do, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + BC = min(BT, T - i_t * BT) + NT = tl.num_programs(1) + + p_g = g + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_gc = gc + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_dx = dx + (i_b * T + i_t * BT + BC - 1) * D + o_d + p_do = do + (i_b * T + i_t * BT + BC - 1) * D + o_d + + if i_t == NT - 1: + b_gc = tl.zeros([BD], dtype=tl.float32) + else: + b_gc = tl.load(g + (i_b * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32) + b_dh = tl.zeros([BD], dtype=tl.float32) + for _ in range(BC - 1, -1, -1): + tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask) + + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + + b_gc = b_gc + b_g + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * exp(b_g) + + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + + p_g -= D + p_gc -= D + p_dx -= D + p_do -= D + + +@triton.jit(do_not_specialize=['T']) +def chunk_hgrn_bwd_kernel_o( + g, + gc, + o, + dx, + dg, + s_b, + s_t, + s_d, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr +): + i_d, i_b = tl.program_id(0), tl.program_id(1) + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + for i_t in range(tl.cdiv(T, BT) - 1, -1, -1): + p_g = tl.make_block_ptr(g + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0)) + p_dx = tl.make_block_ptr(dx + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_dg = tl.make_block_ptr(dg + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + + # [BD,] + mask_t = mask & ((i_t + 1) * BT < T) + b_ht = tl.load(dx + i_b * T * D + 
(i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32) + # [BT, BD] + b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32) + b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32) + b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32) + + b_dx = b_dx + exp(b_gc) * b_ht[None, :] + b_dg = b_o * b_dx * exp(b_g) + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkHGRNFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, x, g, initial_state=None, output_final_state=False): + B, T, D = x.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + o = torch.empty_like(x, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B) + chunk_hgrn_fwd_kernel_h[grid]( + x, g, gc, o, initial_state, + T=T, D=D, BT=BT, + USE_INITIAL_STATE=initial_state is not None + ) + def grid(meta): return (triton.cdiv(D, meta['BD']), B) + chunk_hgrn_fwd_kernel_o[grid]( + gc, o, + o.stride(-3), o.stride(-2), o.stride(-1), + T=T, D=D, BT=BT, BD=BD, + num_warps=num_warps + ) + final_state = None + if output_final_state: + final_state = o[:, -1].clone() + o = o.to(x.dtype) + ctx.save_for_backward(g, o, initial_state) + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + g, o, initial_state = ctx.saved_tensors + B, T, D = do.shape + BT, BD = 128, min(64, triton.next_power_of_2(D)) + num_warps = 8 if BD == 64 else 4 + + gc = torch.empty_like(g, dtype=torch.float) + dx = torch.empty_like(o, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B) + chunk_hgrn_bwd_kernel_h[grid]( + g, gc, dx, do, + T=T, D=D, BT=BT + ) + + dg = torch.empty_like(g, dtype=torch.float) + def grid(meta): return (triton.cdiv(D, meta['BD']), B) + chunk_hgrn_bwd_kernel_o[grid]( + g, gc, o, dx, dg, + o.stride(-3), o.stride(-2), o.stride(-1), + T=T, D=D, BT=BT, BD=BD, + num_warps=num_warps + ) + if initial_state is not None: + dg[:, 0] = (initial_state * dx[:, 0] * g[:, 0].float().exp()).to(dg.dtype) + + return dx.to(o.dtype), dg, None, None + + +@torch.compiler.disable +def chunk_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state) diff --git a/fla/ops/hgrn/fused_recurrent.py b/fla/ops/hgrn/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a70f0c7e4e12fc3648f1f0c19fc946fb85eb97 --- /dev/null +++ b/fla/ops/hgrn/fused_recurrent.py @@ -0,0 +1,308 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def 
fused_recurrent_hgrn_fwd_kernel( + x, + g, + o, + h0, + ht, + offsets, + T, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_n = tl.program_id(0), tl.program_id(1) + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_x = x + bos * D + o_d + p_g = g + bos * D + o_d + p_o = o + bos * D + o_d + + b_h = tl.zeros([BD], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_n * D + o_d + b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32) + for _ in range(0, T): + b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32) + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_h = exp(b_g) * b_h + b_x + tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask) + + p_x += D + p_g += D + p_o += D + + if STORE_FINAL_STATE: + p_ht = ht + i_n * D + o_d + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['D'] +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_hgrn_bwd_kernel( + g, + o, + h0, + dx, + dg, + do, + dht, + dh0, + offsets, + T, + D: tl.constexpr, + BD: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_n = tl.program_id(0), tl.program_id(1) + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + + o_d = i_d * BD + tl.arange(0, BD) + mask = o_d < D + + p_g = g + (bos + T - 1) * D + o_d + p_o = o + (bos + T - 2) * D + o_d + p_dx = dx + (bos + T - 1) * D + o_d + p_dg = dg + (bos + T - 1) * D + o_d + p_do = do + (bos + T - 1) * D + o_d + + b_dh = tl.zeros([BD], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = dht + i_n * D + o_d + b_dh += tl.load(p_dht, mask=mask, other=0).to(tl.float32) + + for i in range(T - 1, -1, -1): + b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32) + if i > 0: + b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32) + elif USE_INITIAL_STATE: + b_o = tl.load(h0 + i_n * D + o_d, mask=mask, other=0).to(tl.float32) + else: + b_o = tl.zeros([BD], dtype=tl.float32) + + b_dh = b_dh + b_do + b_dx = b_dh + b_dh = b_dh * exp(b_g) + b_dg = b_dh * b_o + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask) + + p_g -= D + p_o -= D + p_dx -= D + p_dg -= D + p_do -= D + + if USE_INITIAL_STATE: + p_dh0 = dh0 + i_n * D + o_d + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask) + + +def fused_recurrent_hgrn_fwd( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, D = x.shape + N = B if offsets is None else len(offsets) - 1 + + o = torch.empty_like(x) + final_state = x.new_empty(N, D) if output_final_state else None + + def grid(meta): return 
(triton.cdiv(D, meta['BD']), N) + fused_recurrent_hgrn_fwd_kernel[grid]( + x=x, + g=g, + o=o, + h0=initial_state, + ht=final_state, + offsets=offsets, + T=T, + D=D + ) + return o, final_state + + +def fused_recurrent_hgrn_bwd( + g: torch.Tensor, + o: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor = None, + initial_state: torch.Tensor = None, + offsets: Optional[torch.LongTensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + B, T, D = do.shape + N = B if offsets is None else len(offsets) - 1 + + dx = torch.empty_like(o, dtype=torch.float) + dg = torch.empty_like(g, dtype=torch.float) + dh0 = torch.empty_like(initial_state, dtype=torch.float) if initial_state is not None else None + def grid(meta): return (triton.cdiv(D, meta['BD']), N) + fused_recurrent_hgrn_bwd_kernel[grid]( + g=g, + o=o, + h0=initial_state, + dx=dx, + dg=dg, + do=do, + dht=dht, + dh0=dh0, + offsets=offsets, + T=T, + D=D + ) + return dx, dg, dh0 + + +class FusedRecurrentHGRNFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward( + ctx, + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None + ): + o, ht = fused_recurrent_hgrn_fwd( + x=x, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets + ) + ctx.save_for_backward(g, o, initial_state) + ctx.offsets = offsets + return o, ht + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + g, o, initial_state = ctx.saved_tensors + offsets = ctx.offsets + + dx, dg, dh0 = fused_recurrent_hgrn_bwd( + g=g, + o=o, + do=do, + dht=dht, + initial_state=initial_state, + offsets=offsets + ) + return dx, dg, dh0, None, None + + +@torch.compiler.disable +def fused_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + x (torch.Tensor): + inputs of shape `[B, T, D]. + g (torch.Tensor): + Forget gates of shape `[B, T, D]`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, D]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, D]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, D]`. + final_state (torch.Tensor): + Final state of shape `[N, D]` if `output_final_state=True` else `None`. 
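+ + Note: + Following the reference implementation in `fla/ops/hgrn/naive.py`, the recurrence computed here is the + element-wise update `h_t = exp(g_t) * h_{t-1} + x_t` with `o_t = h_t`; the Triton kernel fuses the entire + scan over `T` into a single launch, with one program per `(feature block, sequence)` pair.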
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.hgrn import fused_recurrent_hgrn + # inputs with equal lengths + >>> B, T, D = 4, 2048, 512 + >>> x = torch.randn(B, T, D, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, D, device='cuda')) + >>> h0 = torch.randn(B, D, device='cuda') + >>> o, ht = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> x, g = map(lambda x: rearrange(x, 'b t d -> 1 (b t) d'), (x, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = x.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_hgrn(x, g, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + return FusedRecurrentHGRNFunction.apply( + x, + g, + initial_state, + output_final_state, + cu_seqlens + ) diff --git a/fla/ops/hgrn/naive.py b/fla/ops/hgrn/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcddc1967b31c5181d330704c7b5ff2127e9d68 --- /dev/null +++ b/fla/ops/hgrn/naive.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch + + +def naive_recurrent_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +) -> torch.Tensor: + dtype = x.dtype + x, g = map(lambda i: i.float(), (x, g)) + B, T, D = x.shape + + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + o = torch.zeros_like(x) + + final_state = None + if initial_state is not None: + h += initial_state + + for i in range(T): + h = g[:, i].exp() * h + x[:, i] + o[:, i] = h + + if output_final_state: + final_state = h + return o.to(dtype), final_state + + +def naive_chunk_hgrn( + x: torch.Tensor, + g: torch.Tensor, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False, + chunk_size: int = 64 +) -> torch.Tensor: + dtype = x.dtype + x, g = map(lambda i: i.float(), (x, g)) + B, T, D = x.shape + + gc = g.view(B, chunk_size, D).cumsum(-2).view_as(g) + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + o = torch.zeros_like(x) + + final_state = None + if initial_state is not None: + h += initial_state + + for i in range(0, T, chunk_size): + hp = h + h = torch.zeros(B, D, dtype=torch.float, device=x.device) + for j in range(i, i + chunk_size): + h = g[:, j].exp() * h + x[:, j] + o[:, j] = hp * gc[:, j].exp() + h + h = o[:, j].clone() + + if output_final_state: + final_state = h + return o.to(dtype), final_state diff --git a/fla/ops/ttt/__init__.py b/fla/ops/ttt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e3d66e42f18785546b4e3be77abc2c91519e3bb9 --- /dev/null +++ b/fla/ops/ttt/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_ttt_linear +from .fused_chunk import fused_chunk_ttt_linear + +__all__ = [ + 'fused_chunk_ttt_linear', + 'chunk_ttt_linear' +] diff --git a/fla/ops/ttt/chunk.py b/fla/ops/ttt/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..6342364268cfc56b3b87601e902f0ab40d6a1b7f --- /dev/null +++ b/fla/ops/ttt/chunk.py @@ -0,0 +1,1539 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan + +from typing import Optional, Tuple + 
+import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +from fla.modules.layernorm import group_norm +from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BT', 'BK', 'BV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_ttt_linear_fwd_kernel_h( + k, + v, + v_new, + eta, + w, + b, + eps, + h, + hb, + h0, + hb0, + ht, + hbt, + offsets, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_hb = tl.zeros([BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + if USE_INITIAL_STATE_B: + p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32) + + offs = tl.arange(0, BV) + b_w = tl.load(w + i_h * V + offs, mask=offs < V, other=0.) + b_b = tl.load(b + i_h * V + offs, mask=offs < V, other=0.) 
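+    # Per-chunk TTT-linear update, as implemented by the loop below:
+    #   1. stash the pre-chunk fast-weight state (b_h, b_hb) so the output kernel can read it;
+    #   2. recompute z = k @ W + b_hb and its LayerNorm statistics (mean, rstd);
+    #   3. b_v2 is the per-token update direction, i.e. the reconstruction error
+    #      backpropagated through the LayerNorm;
+    #   4. advance the state with the chunk's last learning rate:
+    #      W <- W - eta_last * sum_t k_t v2_t^T,  b_hb <- b_hb - eta_last * sum_t v2_t.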
+ + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_hb = tl.make_block_ptr(hb + (i_nh * NT + i_t) * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + else: + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_hb = tl.make_block_ptr(hb + ((boh + i_t) * H + i_h) * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_hb, b_hb.to(p_hb.dtype.element_ty), boundary_check=(0,)) + if HEAD_FIRST: + p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_eta_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1 + else: + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0)) + p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + + b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :] + b_kh = tl.where((offs < V)[None, :], b_kh, 0.) + mean = tl.sum(b_kh, axis=1, keep_dims=True) / V + xbar = tl.where((offs < V)[None, :], b_kh - mean, 0.) + var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V + rstd = 1 / tl.sqrt(var.to(tl.float32) + eps) + b_kh_hat = (b_kh - mean) * rstd + + b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \ + b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k) + b_v = tl.where((offs < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.) 
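+        # The next statement is the standard LayerNorm input-gradient:
+        # with g = dL/dx_hat (held in b_v) and x_hat = b_kh_hat,
+        #   dL/dz = rstd * (V * g - sum(g) - x_hat * sum(g * x_hat)) / V.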
+ b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype) + * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V + tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + b_eta_last = tl.load(p_eta_last) + b_h = b_h - tl.dot(b_eta_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False) + b_hb = b_hb - tl.sum(b_eta_last * b_v2.to(b_k.dtype), axis=0) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_ttt_linear_fwd_kernel_o( + q, + k, + v, + eta, + h, + hb, + o, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K) + k += (i_bh * T * K) if HEAD_FIRST else ((bos * H + i_h) * K) + v += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V) + eta += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + o += (i_bh * T * V) if HEAD_FIRST else ((bos * H + i_h) * V) + h += ((i_bh * NT + i_t) * K * V) if HEAD_FIRST else ((i_tg * H + i_h) * K * V) + hb += ((i_bh * NT + i_t) * V) if HEAD_FIRST else ((i_tg * H + i_h) * V) + stride_qk = K if HEAD_FIRST else H*K + stride_vo = V if HEAD_FIRST else H*V + stride_eta = 1 if HEAD_FIRST else H + + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (0, i_t * BT), (BK, BT), (0, 1)) + p_eta = tl.make_block_ptr(eta, (T,), (stride_eta,), (i_t * BT,), (BT,), (0,)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0)) + p_hb = tl.make_block_ptr(hb, (V,), (1,), (i_v * BV,), (BV,), (0,)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero") + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + # [BT, 1] + b_eta = tl.load(p_eta, boundary_check=(0,), padding_option="zero") + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero") + # [BV] + b_hb = tl.load(p_hb, boundary_check=(0,), padding_option="zero") + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o = tl.dot(b_q, b_h, allow_tf32=False) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A = tl.dot(b_q, b_k, allow_tf32=False) + + o_i = tl.arange(0, BT) + m_A = o_i[:, None] >= o_i[None, :] + b_A = tl.where(m_A, b_A, 0) + b_Ae = tl.where(m_A, b_eta[:, None], 0.0) + + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), 
(i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + b_o = (b_o - tl.dot(b_eta[:, None] * b_A.to(b_v.dtype), b_v, allow_tf32=False)) * scale + b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v.dtype), b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_ttt_linear_bwd_kernel_h( + k, + v, + v_new, + eta, + w, + b, + eps, + h, + h0, + hb0, + x, + y, + r, + offsets, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_hb = tl.zeros([BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + if USE_INITIAL_STATE_B: + p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32) + + offs = tl.arange(0, BV) + b_w = tl.load(w + i_h * V + offs, mask=offs < V, other=0.) + b_b = tl.load(b + i_h * V + offs, mask=offs < V, other=0.) 
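+    # Backward "h" pass: replay the forward state recurrence chunk by chunk and,
+    # besides the per-chunk states in `h`, write out the LayerNorm intermediates
+    # (x = normalized pre-activation, y = dL/dx_hat, r = rstd) and v_new, which the
+    # norm-backward kernel consumes so the forward pass never has to store them.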
+ + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + if HEAD_FIRST: + p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t * BT, 0), (BT, 1), (1, 0)) + p_eta_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1 + else: + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + + b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :] + b_kh = tl.where((offs < V)[None, :], b_kh, 0.) + mean = tl.sum(b_kh, axis=1, keep_dims=True) / V + xbar = tl.where((offs < V)[None, :], b_kh - mean, 0.) + var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V + rstd = 1 / tl.sqrt(var.to(tl.float32) + eps) + b_kh_hat = (b_kh - mean) * rstd + + b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \ + b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k) + b_v = tl.where((offs < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.) 
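+        # Same LayerNorm input-gradient as in the forward kernel; x_hat, dL/dx_hat,
+        # rstd and the updated targets v2 are additionally stored below for reuse.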
+ b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype) + * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V + tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + b_eta_last = tl.load(p_eta_last) + b_h = b_h - tl.dot(b_eta_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False) + b_hb = b_hb - tl.sum(b_eta_last * b_v2.to(b_k.dtype), axis=0) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [4] + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_ttt_linear_bwd_kernel_dv_local( + q, + k, + eta, + do, + dv, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + eta += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + stride_qk = K if HEAD_FIRST else H*K + stride_vo = V if HEAD_FIRST else H*V + stride_eta = 1 if HEAD_FIRST else H + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, b_q) + + p_eta = tl.make_block_ptr(eta, (T,), (stride_eta,), (i_t * BT,), (BT,), (0,)) + b_eta = tl.load(p_eta, boundary_check=(0,)) + mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) + b_A = - tl.where(mask, b_A * scale * b_eta[None, :], 0).to(do.dtype.element_ty) + b_Ae = - tl.where(mask, b_eta[None, :], 0).to(do.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [2, 4, 8, 
16] + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_ttt_linear_bwd_kernel_norm( + q, + k, + v, + v_new, + x, + y, + r, + w, + b, + eta, + h, + dht, + dhbt, + dh0, + dhb0, + do, + dh, + dhb, + dv, + dv_new, + dk, + dw, + db, + offsets, + chunk_offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT_B: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_dhb = tl.zeros([BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero") + if USE_FINAL_STATE_GRADIENT_B: + p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero") + + # [BV] + offs_v = tl.arange(0, BV) + offs_t = tl.arange(0, BT) + b_w = tl.load(w + i_h * V + offs_v, mask=offs_v < V, other=0.) + b_b = tl.load(b + i_h * V + offs_v, mask=offs_v < V, other=0.) + b_dw = tl.zeros([BV,], dtype=b_w.dtype) + b_db = tl.zeros([BV,], dtype=b_b.dtype) + p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + + for i_t in range(NT - 1, -1, -1): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dhb = tl.make_block_ptr(dhb + (i_nh * NT + i_t) * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + else: + p_h = tl.make_block_ptr(h + ((boh+i_t) * H + i_h) * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dhb = tl.make_block_ptr(dhb + ((boh+i_t) * H + i_h) * V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dhb, b_dhb.to(p_dhb.dtype.element_ty), boundary_check=(0,)) + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv_new = tl.make_block_ptr(dv_new + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = 
tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_nh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r + i_nh * T, (T, 1), (1, 1), (i_t * BT, 0), (BT, 1), (1, 0)) + p_eta_last = eta + i_nh*T + T - 1 if i_t == NT-1 else eta + i_nh*T + i_t*BT + BT - 1 + else: + p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv_new = tl.make_block_ptr(dv_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, i_k * BK), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, i_v * BV), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_eta_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + b_dv_new = tl.load(p_dv_new, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype) + b_eta_last = tl.load(p_eta_last) + b_dv_new -= tl.dot(b_eta_last * b_k, b_dh.to(b_k.dtype)) + b_dv_new -= b_eta_last * b_dhb.to(b_k.dtype)[None, :] + + b_v_new = tl.load(p_v_new, boundary_check=(0, 1), padding_option="zero") + b_x = tl.load(p_x, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype) + b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype) + b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) - + b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V + b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) + + b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V + b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v_new.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True) + + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + b_w = b_w.to(b_k.dtype) + b_b = b_b.to(b_k.dtype) + b_dv = -b_w * b_dy.to(b_k.dtype) + b_dk = b_w * b_dy.to(b_k.dtype) + b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) + + (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype) + b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype) + b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype) + + # d_rstd, dx --> dkh --> dk, dh + b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero") + b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero") + b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero") + b_q = (b_q * scale).to(b_q.dtype) + b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) - + b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V + b_dkh -= b_rstd 
* b_rstd * b_drstd * b_x / V + b_dkh = tl.where((offs_v < V)[None, :] * (offs_t < T-i_t*BT)[:, None], b_dkh, 0.) + b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype) + b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh) + b_dhb += tl.sum(b_do + b_dkh, axis=0) + b_dh = tl.where((offs_v < V)[None, :], b_dh, 0.) + b_dhb = tl.where((offs_v < V), b_dhb, 0.) + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if USE_INITIAL_STATE_B: + p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (i_v * BV,), (BV,), (0,)) + tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3] + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dqke( + q, + k, + v, + e, + h, + do, + dh, + dhb, + dq, + dk, + de, + offsets, + indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + h += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V + dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V + dhb += (i_bh * NT + i_t) * V if HEAD_FIRST else (i_tg * H + i_h) * V + q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dq += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + e += i_bh * T if HEAD_FIRST else (bos * H + i_h) + de += i_bh * T if HEAD_FIRST else (bos * H + i_h) + stride_qk = K if HEAD_FIRST else H*K + stride_vo = V if HEAD_FIRST else H*V + stride_e = 1 if HEAD_FIRST else H + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + b_de = tl.zeros([BT,], dtype=tl.float32) + + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_e_last = (e + (i_t*BT+BT-1)*stride_e) if (i_t*BT+BT) <= T else (e + (T-1)*stride_e) + i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT-1) + mask = (tl.arange(0, BT) == i_last) + b_e_last = tl.load(p_e_last) + + for i_v in range(tl.cdiv(V, BV)): + p_v = 
tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dhb = tl.make_block_ptr(dhb, (V,), (1,), (i_v * BV,), (BV,), (0,)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + # [BV] + b_dhb = tl.load(p_dhb, boundary_check=(0,)) + # [BT, BV] @ [BV, BT] -> [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dk -= b_e_last * tl.dot(b_v, b_dh.to(b_v.dtype)) + b_de -= mask * tl.sum(tl.trans(b_dh) * tl.dot(tl.trans(b_k), b_v.to(b_k.dtype))) + b_de -= mask * tl.sum(b_dhb * tl.sum(b_v, axis=0).to(b_k.dtype)) + + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_e = tl.make_block_ptr(e, (T,), (stride_e,), (i_t * BT,), (BT,), (0,)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_e = tl.load(p_e, boundary_check=(0,)) + + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_de = tl.make_block_ptr(de, (T,), (stride_e,), (i_t * BT,), (BT,), (0,)) + + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0) + b_ds = b_ds.to(b_k.dtype) + b_dq -= tl.dot(b_ds, b_k) * b_e[:, None] + b_dk -= tl.dot(tl.trans(b_ds), b_q * b_e[:, None]) * scale + b_de -= tl.sum(scale * tl.dot(b_ds, b_k) * b_q, axis=1) + b_de -= tl.sum(b_ds, axis=1) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,)) + + +def chunk_ttt_linear_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + eps: float, + initial_state: Optional[torch.Tensor] = None, + initial_state_bias: Optional[torch.Tensor] = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16, +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128." + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + assert NV == 1, 'NV > 1 is not supported by TTT update rule.' 
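+    # The LayerNorm statistics inside the kernel are taken over the full value
+    # dimension, so each program must hold one complete [K, V] state: BK/BV span
+    # the whole head dimension and NK == NV == 1 is enforced above.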
+ + if head_first: + h = k.new_empty(B, H, NT, K, V) + hb = k.new_empty(B, H, NT, 1, V) + else: + h = k.new_empty(B, NT, H, K, V) + hb = k.new_empty(B, NT, H, 1, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(v) + grid = (NK, NV, N * H) + + chunk_ttt_linear_fwd_kernel_h[grid]( + k=k, + v=v, + v_new=v_new, + eta=eta, + w=w, + b=b, + eps=eps, + h=h, + hb=hb, + h0=initial_state, + hb0=initial_state_bias, + ht=final_state, + hbt=final_state_bias, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + NT=NT, + HEAD_FIRST=head_first + ) + return h, hb, v_new, final_state, final_state_bias + + +def chunk_ttt_linear_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + eta: torch.Tensor, + h: torch.Tensor, + hb: torch.Tensor, + scale: Optional[float] = None, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *q.shape, v.shape[-1] + else: + B, T, H, K, V = *q.shape, v.shape[-1] + if scale is None: + scale = k.shape[-1] ** -0.5 + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + assert NV == 1, 'NV > 1 is not supported by TTT update rule.' + + o = torch.empty_like(v) + + grid = (NV, NT, B * H) + chunk_ttt_linear_fwd_kernel_o[grid]( + q, + k, + v, + eta, + h, + hb, + o, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return o + + +def chunk_ttt_linear_bwd_h( + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + eps: float, + initial_state: Optional[torch.Tensor] = None, + initial_state_bias: Optional[torch.Tensor] = None, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16, +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128." + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + assert NV == 1, 'NV > 1 is not supported by TTT update rule.' 
+ + if head_first: + h = k.new_empty(B, H, NT, K, V) + rstd = v.new_empty(B, H, T, 1, dtype=torch.float32) + else: + h = k.new_empty(B, NT, H, K, V) + rstd = v.new_empty(B, T, H, 1, dtype=torch.float32) + x = torch.empty_like(v) + y = torch.empty_like(v) + + v_new = torch.empty_like(v) + grid = (NK, NV, N * H) + + chunk_ttt_linear_bwd_kernel_h[grid]( + k=k, + v=v, + v_new=v_new, + eta=eta, + w=w, + b=b, + eps=eps, + h=h, + h0=initial_state, + hb0=initial_state_bias, + x=x, + y=y, + r=rstd, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + NT=NT, + HEAD_FIRST=head_first + ) + return h, v_new, x, y, rstd + + +def chunk_ttt_linear_bwd_dv_local( + q: torch.Tensor, + k: torch.Tensor, + eta: torch.Tensor, + do: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *k.shape, do.shape[-1] + else: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + BK = min(triton.next_power_of_2(K), 128) + BV = min(triton.next_power_of_2(V), 128) + + dv = torch.empty_like(do) + grid = (NT, B * H) + chunk_ttt_linear_bwd_kernel_dv_local[grid]( + q, + k, + eta, + do, + dv, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dv + + +def chunk_ttt_linear_bwd_norm( + q: torch.Tensor, # [B, H, L, D] + k: torch.Tensor, # [B, H, L, D] + v: torch.Tensor, # [B, H, L, D] + v_new: torch.Tensor, # [B, H, L, D] + x: torch.Tensor, # [B, H, L, D] + y: torch.Tensor, # [B, H, L, D] + rstd: torch.Tensor, # [B, H, L, 1] + w: torch.Tensor, # [H, D] + b: torch.Tensor, # [H, D] + eta: torch.Tensor, # [B, H, L, 1] + h0: torch.Tensor, # [B, H, D, D] + hb0: torch.Tensor, # [B, H, 1, D] + h: torch.Tensor, # [B, H, NT, D, D] + dht: Optional[torch.Tensor], # [B, H, D, D] + dhbt: Optional[torch.Tensor], # [B, H, 1, D] + dv_new: Optional[torch.Tensor], # [B, H, L, D] + do: torch.Tensor, # [B, H, L, D] + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # torch implementation of `dkh, dw, db, dk, dv` for LN^2 + assert offsets is None, "bwd of varlen is not implemented yet." + if head_first: + B, H, T, K, V = *q.shape, do.shape[-1] + else: + B, T, H, K, V = *q.shape, do.shape[-1] + BT = chunk_size + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + + BK = triton.next_power_of_2(K) + BV = triton.next_power_of_2(V) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported by TTT.' + assert NV == 1, 'NV > 1 is not supported by TTT.' 
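+    # dw/db are accumulated per (batch, head) inside the kernel and only reduced
+    # over the batch dimension on the host (dw.sum(dim=0) below), so every program
+    # writes its own slice and no atomics are needed.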
+ + if head_first: + dh = q.new_empty(B, H, NT, K, V) + dhb = q.new_empty(B, H, NT, 1, V) + else: + dh = q.new_empty(B, NT, H, K, V) + dhb = q.new_empty(B, NT, H, 1, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dhb0 = torch.empty_like(hb0, dtype=torch.float32) if hb0 is not None else None + dv = torch.empty_like(v) + dk = torch.empty_like(k) + dw = w.new_empty(B, H, V) + db = b.new_empty(B, H, V) + + grid = (NK, NV, N * H) + chunk_ttt_linear_bwd_kernel_norm[grid]( + q=q, + k=k, + v=v, + v_new=v_new, + x=x, + y=y, + r=rstd, + w=w, + b=b, + eta=eta, + h=h, + dht=dht, + dhbt=dhbt, + dh0=dh0, + dhb0=dhb0, + do=do, + dh=dh, + dhb=dhb, + dv=dv, + dv_new=dv_new, + dk=dk, + dw=dw, + db=db, + offsets=offsets, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + dw = dw.sum(dim=0) + db = db.sum(dim=0) + return dh, dhb, dh0, dhb0, dv, dk, dw, db + + +def chunk_ttt_linear_bwd_norm_ref( + q: torch.Tensor, # [B, H, L, D] + k: torch.Tensor, # [B, H, L, D] + v: torch.Tensor, # [B, H, L, D] + v_new: torch.Tensor, # [B, H, L, D] + kh: torch.Tensor, # [B, H, L, D] + y: torch.Tensor, # [B, H, L, D] + w: torch.Tensor, # [H, D] + b: torch.Tensor, # [H, D] + eta: torch.Tensor, # [B, H, L, 1] + h0: torch.Tensor, # [B, H, D, D] + h: torch.Tensor, # [B, H, NT, D, D] + dht: Optional[torch.Tensor], # [B, H, D, D] + dv_new: Optional[torch.Tensor], # [B, H, L, D] + do: torch.Tensor, # [B, H, L, D] + scale: float, + eps: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # torch implementation of `dkh, dw, db, dk, dv` for LN^2 + assert offsets is None, "bwd of varlen is not implemented yet." 
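+    # Pure-PyTorch reference path, presumably kept for checking the Triton
+    # norm-backward kernel; it follows the same recompute-then-backprop-through-
+    # LayerNorm scheme as the kernels above.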
+ if head_first: + B, H, T, K, V = *q.shape, do.shape[-1] + else: + B, T, H, K, V = *q.shape, do.shape[-1] + # [B, L, H, D] -> [B, H, L, D] + q, k, v, v_new, kh, y, h, eta, dv_new, do = [ + x.transpose(1, 2) for x in + [q, k, v, v_new, kh, y, h, eta, dv_new, do] + ] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + pad_len = (BT - (T % BT)) % BT + if pad_len > 0: + q, k, v, v_new, kh, y, eta, dv_new, do = [ + F.pad(x, (0, 0, 0, pad_len)) for x in + [q, k, v, v_new, kh, y, eta, dv_new, do] + ] + eta[:, :, -1, :] = eta[:, :, -(pad_len+1), :] + # [NT, B, H, BT, D] + q, k, v, v_new, kh, y, eta, dv_new, do = [ + x.reshape(B, H, NT, BT, -1).permute(2, 0, 1, 3, 4) for x in + [q, k, v, v_new, kh, y, eta, dv_new, do] + ] + h = h.permute(2, 0, 1, 3, 4) + + # allocate + dh = q.new_zeros(NT, B, H, K, V) + dv = torch.zeros_like(v) + dk = torch.zeros_like(k) + dw = torch.zeros_like(w) + db = torch.zeros_like(b) + # recurrent state + b_dh = dht if dht is not None else torch.zeros_like(dh[0]) + b_dh = b_dh.to(torch.float32) + + # [H, 1, D] + _w = w.reshape(H, 1, V).to(torch.float32) + _b = b.reshape(H, 1, V).to(torch.float32) + + # d_state passing + for i_t in range(NT - 1, -1, -1): + dh[i_t] = b_dh.to(dh.dtype) + # [B, H, BT, D] + _q, _k, _v, _v_new, _kh, _y, _h, _eta, _dv_new, _do = [ + x[i_t].to(torch.float32) for x in + (q, k, v, v_new, kh, y, h, eta, dv_new, do) + ] + _dv_new -= (_eta[:, :, -1, :, None] * _k) @ b_dh + + mean = _kh.mean(dim=-1, keepdim=True) + var = _kh.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32) + rstd = 1 / torch.sqrt(var + eps).to(torch.float32) + x = (_kh - mean) * rstd + # [B, H, BT, D] + dy = rstd * (_dv_new*V - _dv_new.sum(dim=-1, keepdim=True) - x*(x*_dv_new).sum(dim=-1, keepdim=True)) / V + dx = -rstd * (_dv_new*(x*_y).sum(dim=-1, keepdim=True) + _y*(x*_dv_new).sum(dim=-1, keepdim=True)) / V + d_rstd = (_dv_new * _v_new / rstd).sum(dim=-1, keepdim=True) + + dv[i_t] = (-_w*dy).to(dv.dtype) + dk[i_t] += (_w*dy).to(dk.dtype) + dw += (2*_w*x*dy+(_b-_v+_k)*dy).sum(dim=(0, 2)).to(dw.dtype) + db += (_w*dy).sum(dim=(0, 2)).to(db.dtype) + dx += _w*_w*dy + + # d_rstd, dx --> dkh --> dk, dh + dkh = rstd * (V * dx - dx.sum(dim=-1, keepdim=True) - x * (x * dx).sum(dim=-1, keepdim=True)) / V + dkh -= rstd**2 * d_rstd * x / V + dk[i_t] += (dkh @ _h.transpose(-2, -1)).to(dk.dtype) + b_dh += (_q.transpose(-2, -1) * scale) @ _do + _k.transpose(-2, -1) @ dkh + dh0 = b_dh.to(torch.float32) if h0 is not None else None + + # [NT, B, H, BT, D] -> [B, H, T, D] + dv = dv.permute(1, 2, 0, 3, 4).reshape(B, H, -1, V)[:, :, :T, :] + dk = dk.permute(1, 2, 0, 3, 4).reshape(B, H, -1, K)[:, :, :T, :] + # [B, H, NT, D, D] + dh = dh.permute(1, 2, 0, 3, 4) + if not head_first: + dv, dk, dh = [x.transpose(1, 2) for x in (dv, dk, dh)] + dh, dv, dk, dw, db = [x.contiguous() for x in (dh, dv, dk, dw, db)] + dh0 = dh0.contiguous() if h0 is not None else None + return dh, dh0, dv, dk, dw, db + + +def chunk_ttt_linear_bwd_dqke( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + eta: torch.Tensor, + h: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dhb: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 16, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else 
len(indices) + + BK = triton.next_power_of_2(K) + BV = min(triton.next_power_of_2(V), 64) + NK = triton.cdiv(K, BK) + assert NK == 1, "NK > 1 is not supported." + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + de = torch.empty_like(eta) + grid = (NK, NT, B * H) + + chunk_bwd_kernel_dqke[grid]( + q=q, + k=k, + v=v, + e=eta, + h=h, + do=do, + dh=dh, + dhb=dhb, + dq=dq, + dk=dk, + de=de, + offsets=offsets, + indices=indices, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dq, dk, de + + +def chunk_ttt_linear_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + eps: float, + initial_state: torch.Tensor, + initial_state_bias: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + BT: int = 16 +): + h, hb, v_new, final_state, final_state_bias = chunk_ttt_linear_fwd_h( + k=k, + v=v, + w=w, + b=b, + eta=eta, + eps=eps, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + o = chunk_ttt_linear_fwd_o( + q=q, + k=k, + v=v_new, + eta=eta, + h=h, + hb=hb, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + return o, final_state, final_state_bias + + +def chunk_ttt_linear_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + eps: float, + do: torch.Tensor, + dht: torch.Tensor, + dhbt: torch.Tensor, + BT: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True +): + h, v_new, x, y, rstd = chunk_ttt_linear_bwd_h( + k=k, + v=v, + w=w, + b=b, + eta=eta, + eps=eps, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dv_new = chunk_ttt_linear_bwd_dv_local( + q=q, + k=k, + eta=eta, + do=do, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dh, dhb, dh0, dhb0, dv, dk, dw, db = chunk_ttt_linear_bwd_norm( + q=q, + k=k, + v=v, + v_new=v_new, + x=x, + y=y, + rstd=rstd, + w=w, + b=b, + eta=eta, + h0=initial_state, + hb0=initial_state_bias, + h=h, + dht=dht, + dhbt=dhbt, + dv_new=dv_new, + do=do, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dq, dk2, de = chunk_ttt_linear_bwd_dqke( + q=q, + k=k, + v=v_new, + eta=eta, + h=h, + do=do, + dh=dh, + dhb=dhb, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=BT + ) + dk.add_(dk2) + return dq, dk, dv, de, dw, db, dh0, dhb0 + + +class ChunkTTTLinearFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state, + initial_state_bias, output_final_state, offsets, head_first): + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], 
[1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, BT) if offsets is not None else None + o, final_state, final_state_bias = chunk_ttt_linear_fwd( + q=q, + k=k, + v=v, + w=w, + b=b, + eta=eta, + scale=scale, + eps=eps, + BT=BT, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + ) + ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias) + ctx.BT = BT + ctx.scale = scale + ctx.eps = eps + ctx.offsets = offsets + ctx.indices = indices + ctx.head_first = head_first + return o.to(q.dtype), final_state, final_state_bias + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht, dhbt): + q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors + dq, dk, dv, de, dw, db, dh0, dhb0 = chunk_ttt_linear_bwd( + q=q, + k=k, + v=v, + w=w, + b=b, + eta=eta, + scale=ctx.scale, + eps=ctx.eps, + do=do, + dht=dht, + dhbt=dhbt, + BT=ctx.BT, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + offsets=ctx.offsets, + indices=ctx.indices, + head_first=ctx.head_first + ) + return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None, None + + +def norm_residual(x, weight, bias, eps, head_first): + # GroupNorm and Residual + if head_first: + B, H, T, D = x.shape + x = x.transpose(1, 2) + x += group_norm( + x.reshape(B, T, -1).clone(), + weight=weight.reshape(-1).clone(), + bias=bias.reshape(-1).clone(), + eps=eps, + num_groups=H, + ).reshape(x.shape) + x = x.transpose(1, 2) + else: + B, T, H, D = x.shape + x += group_norm( + x.reshape(B, T, -1).clone(), + weight=weight.reshape(-1).clone(), + bias=bias.reshape(-1).clone(), + eps=eps, + num_groups=H, + ).reshape(x.shape) + return x + + +def chunk_ttt_linear( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float = None, + eps: float = 1e-6, + chunk_size: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True, +): + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + w (torch.Tensor): + layer norm weight of shape `(H, V)` + b (torch.Tensor): + layer norm bias of shape `(H, V)` + eta (torch.Tensor): + Learning rate for hidden state, of shape `(B, H, T, 1)`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + chunk_size (int): + chunk size. Default: `16`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. + initial_state_bias (Optional[torch.Tensor]): + Initial state bias of shape `(B, H, 1, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. 
+ Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` + """ + assert q.dtype == k.dtype == v.dtype + assert k.shape[-1] == v.shape[-1], "DK must equal to DV." + if isinstance(eta, float): + eta = torch.full_like(q[:, :, :, :1], eta) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "Scale must be positive." + o, final_state, final_state_bias = ChunkTTTLinearFunction.apply( + q, + k, + v, + w, + b, + chunk_size, + eta, + scale, + eps, + initial_state, + initial_state_bias, + output_final_state, + cu_seqlens, + head_first, + ) + o = norm_residual(o, w, b, eps, head_first) + return o, final_state, final_state_bias diff --git a/fla/ops/ttt/fused_chunk.py b/fla/ops/ttt/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..08850c170b3e88fa98f1818434baab7d73e93c7a --- /dev/null +++ b/fla/ops/ttt/fused_chunk.py @@ -0,0 +1,896 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.modules.layernorm import group_norm +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def fused_chunk_ttt_linear_fwd_kernel( + q, + k, + v, + eta, + w, + b, + o, + scale, + eps, + h0, + hb0, + ht, + hbt, + offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # indices + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + + o_i = tl.arange(0, BT) + v_i = tl.arange(0, BV) + m_A = o_i[:, None] >= o_i[None, :] + b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.) + b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.) 
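+    # Fused variant: a single program per (batch, head) iterates over all NT chunks,
+    # keeping the running state (b_h, b_hb) in registers and emitting the outputs in
+    # the same pass, so no per-chunk states are materialized in global memory.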
+ + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_hb = tl.zeros([BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + if USE_INITIAL_STATE_B: + p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32) + + for i_t in range(NT): + if HEAD_FIRST: + p_q = tl.make_block_ptr(q+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,)) + p_e_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1 + else: + p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,)) + p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + + # [BT, BV] + b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :] + b_kh = tl.where((v_i < V)[None, :], b_kh, 0.) + mean = tl.sum(b_kh, axis=1, keep_dims=True) / V + xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.) + var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V + rstd = 1 / tl.sqrt(var.to(tl.float32) + eps) + b_kh_hat = (b_kh - mean) * rstd + + b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \ + b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k) + b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.) + b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype) + * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero") + # [BT] + b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero") + b_q = (b_q * scale).to(b_k.dtype) + + # [BT, BT] + b_A = tl.dot(b_q, b_k, allow_tf32=False) + b_A = tl.where(m_A, b_A, 0) + b_Ae = tl.where(m_A, b_e[:, None], 0.0) + + b_o = - tl.dot(b_e[:, None] * b_A.to(b_v2.dtype), b_v2, allow_tf32=False) + b_o += b_hb[None, :] - tl.dot(b_Ae.to(b_v2.dtype), b_v2, allow_tf32=False) + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_e_last = tl.load(p_e_last) + b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False) + b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0) + b_h = tl.where((v_i < V)[None, :], b_h, 0.) + b_hb = tl.where((v_i < V), b_hb, 0.) 
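+        # b_o was assembled from the pre-update state plus the causally masked
+        # intra-chunk terms, so it is safe to store it only after b_h/b_hb have
+        # already been advanced to the next chunk.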
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + p_hbt = tl.make_block_ptr(hbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_hbt, b_hb.to(p_hbt.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['hb0'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def fused_chunk_ttt_linear_bwd_kernel_h( + k, + v, + v2, + x, + y, + r, + w, + b, + eta, + h0, + hb0, + h, + do, + dq, + scale, + eps, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # indices + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + bos, _ = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + o_i = tl.arange(0, BT) + v_i = tl.arange(0, BV) + m_A = o_i[:, None] >= o_i[None, :] + b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.) + b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.) + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_hb = tl.zeros([BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K * V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + if USE_INITIAL_STATE_B: + p_hb0 = tl.make_block_ptr(hb0 + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + b_hb = tl.load(p_hb0, boundary_check=(0,), padding_option="zero").to(tl.float32) + + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h+(i_nh*NT+i_t)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + p_k = tl.make_block_ptr(k+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_v2 = tl.make_block_ptr(v2+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,)) + p_dq = tl.make_block_ptr(dq+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_e_last = eta+i_nh*T+T-1 if i_t == NT-1 else eta+i_nh*T+i_t*BT+BT-1 + else: + p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 
1), (1, 0)) + p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,)) + p_dq = tl.make_block_ptr(dq+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + + b_kh = tl.dot(tl.trans(b_k), b_h.to(b_k.dtype), allow_tf32=False).to(tl.float32) + b_hb[None, :] + b_kh = tl.where((v_i < V)[None, :], b_kh, 0.) + mean = tl.sum(b_kh, axis=1, keep_dims=True) / V + xbar = tl.where((v_i < V)[None, :], b_kh - mean, 0.) + var = tl.sum(xbar * xbar, axis=1, keep_dims=True) / V + rstd = 1 / tl.sqrt(var.to(tl.float32) + eps) + b_kh_hat = (b_kh - mean) * rstd + + b_v = b_kh_hat.to(b_k.dtype) * b_w[None, :].to(b_k.dtype) + \ + b_b[None, :].to(b_k.dtype) - b_v.to(b_k.dtype) + tl.trans(b_k) + b_v = tl.where((v_i < V)[None, :], b_v * b_w[None, :].to(b_k.dtype), 0.) + b_v2 = rstd * (V * b_v - tl.sum(b_v, axis=1, keep_dims=True) - b_kh_hat.to(b_k.dtype) + * tl.sum(b_v * b_kh_hat.to(b_k.dtype), axis=1, keep_dims=True)) / V + tl.store(p_x, b_kh_hat.to(p_x.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_y, b_v.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_r, rstd.to(p_r.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_v2, b_v2.to(p_v2.dtype.element_ty), boundary_check=(0, 1)) + + b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero") + b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero") + + b_v2 = tl.where((v_i < V)[None, :], b_v2, 0.) + b_ds = tl.dot(b_do, tl.trans(b_v2).to(b_do.dtype)) + b_ds = tl.where(m_A, b_ds, 0) + b_ds = b_ds.to(b_k.dtype) + b_dq = tl.dot(b_do, tl.trans(b_h).to(b_do.dtype)) + b_dq -= tl.dot(b_ds, tl.trans(b_k)) * b_e[:, None] + b_dq *= scale + + b_e_last = tl.load(p_e_last) + b_h = b_h - tl.dot(b_e_last * b_k, b_v2.to(b_k.dtype), allow_tf32=False) + b_hb = b_hb - tl.sum(b_e_last * b_v2.to(b_k.dtype), axis=0) + b_h = tl.where((v_i < V)[None, :], b_h, 0.) + b_hb = tl.where((v_i < V), b_hb, 0.) 
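# The recomputation above reproduces the forward LayerNorm quantities (written to x, y, r, v2),
# and the (b_h, b_hb) recurrence matches the forward kernel exactly; the pre-update state of each
# chunk was already stored to h at the top of the loop, so the dh kernel can reuse both without
# re-running the forward pass. b_dq combines the inter-chunk term do @ h^T with the intra-chunk
# correction before being stored below.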
+ tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_INITIAL_STATE_B': lambda args: args['dhb0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_FINAL_STATE_GRADIENT_B': lambda args: args['dhbt'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4) + ], + key=['BT', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def fused_chunk_ttt_linear_bwd_kernel_dh( + q, + k, + v, + v2, + x, + y, + r, + w, + b, + eta, + h, + dht, + dhbt, + dh0, + dhb0, + do, + dk, + dv, + de, + dw, + db, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_INITIAL_STATE_B: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT_B: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # indices + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + bos, _ = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + # [BV] + b_dhb = tl.zeros([BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1), padding_option="zero") + if USE_FINAL_STATE_GRADIENT_B: + p_dhbt = tl.make_block_ptr(dhbt + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + b_dhb += tl.load(p_dhbt, boundary_check=(0,), padding_option="zero") + + # [BV] + o_i = tl.arange(0, BT) + v_i = tl.arange(0, BV) + m_A = o_i[:, None] >= o_i[None, :] + m_A_t = o_i[:, None] <= o_i[None, :] + b_w = tl.load(w + i_h * V + v_i, mask=v_i < V, other=0.) + b_b = tl.load(b + i_h * V + v_i, mask=v_i < V, other=0.) 
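# b_dw and b_db below accumulate the LayerNorm weight/bias gradients over all chunks handled by
# this program; they are written out per (sequence, head) and only summed over the batch
# dimension on the host (see dw.sum(dim=0) / db.sum(dim=0) in fused_chunk_ttt_linear_bwd_dh).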
+ b_dw = tl.zeros([BV,], dtype=b_w.dtype) + b_db = tl.zeros([BV,], dtype=b_b.dtype) + p_dw = tl.make_block_ptr(dw + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + p_db = tl.make_block_ptr(db + i_nh * V, (V,), (1,), (0,), (BV,), (0,)) + + for i_t in range(NT - 1, -1, -1): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h+(i_nh*NT+i_t)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1)) + p_q = tl.make_block_ptr(q+i_nh*T*K, (K, T), (1, K), (0, i_t*BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_v2 = tl.make_block_ptr(v2+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+i_nh*T, (T, 1), (1, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_e = tl.make_block_ptr(eta+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,)) + p_dv = tl.make_block_ptr(dv+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk+i_nh*T*K, (T, K), (K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+i_nh*T*V, (T, V), (V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_de = tl.make_block_ptr(de+i_nh*T, (T,), (1,), (i_t*BT,), (BT,), (0,)) + p_e_last = eta + i_nh*T + T - 1 if i_t == NT-1 else eta + i_nh*T + i_t*BT + BT - 1 + else: + p_h = tl.make_block_ptr(h+((boh+i_t)*H+i_h)*K*V, (V, K), (1, V), (0, 0), (BV, BK), (0, 1)) + p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (0, i_t*BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_v2 = tl.make_block_ptr(v2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_x = tl.make_block_ptr(x+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_y = tl.make_block_ptr(y+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_r = tl.make_block_ptr(r+bos*H+i_h, (T, 1), (H, 1), (i_t*BT, 0), (BT, 1), (1, 0)) + p_e = tl.make_block_ptr(eta+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,)) + p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t*BT, 0), (BT, BK), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT, 0), (BT, BV), (1, 0)) + p_de = tl.make_block_ptr(de+(bos*H+i_h), (T,), (H,), (i_t*BT,), (BT,), (0,)) + p_e_last = eta+bos*H+i_h + (T-1)*H if i_t == NT-1 else eta+bos*H+i_h + (i_t*BT+BT-1)*H + b_q = tl.load(p_q, boundary_check=(0, 1), padding_option="zero") + b_k = tl.load(p_k, boundary_check=(0, 1), padding_option="zero") + b_e = tl.load(p_e, boundary_check=(0,), padding_option="zero") + b_do = tl.load(p_do, boundary_check=(0, 1), padding_option="zero") + b_e_last = tl.load(p_e_last) + b_A = tl.dot(b_k, b_q) + b_A = - tl.where(m_A_t, b_A * scale * b_e[None, :], 0).to(do.dtype.element_ty) + b_Ae = - tl.where(m_A_t, b_e[None, :], 0).to(do.dtype.element_ty) + b_dv_new = tl.dot(b_A.to(b_do.dtype), b_do) + tl.dot(b_Ae.to(b_do.dtype), b_do) + b_dv_new -= tl.dot(b_e_last * b_k, b_dh.to(b_k.dtype)) + b_dv_new -= b_e_last * b_dhb.to(b_k.dtype)[None, :] + + b_v2 = tl.load(p_v2, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype) + b_x = tl.load(p_x, boundary_check=(0, 1), 
padding_option="zero").to(b_k.dtype) + b_y = tl.load(p_y, boundary_check=(0, 1), padding_option="zero").to(b_k.dtype) + b_rstd = tl.load(p_r, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + b_dy = b_rstd * (b_dv_new * V - tl.sum(b_dv_new, axis=1, keep_dims=True) - + b_x * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V + b_dx = -b_rstd * (b_dv_new * tl.sum(b_x * b_y, axis=1, keep_dims=True) + + b_y * tl.sum(b_dv_new * b_x, axis=1, keep_dims=True)) / V + b_drstd = tl.sum(b_dv_new.to(b_rstd.dtype) * b_v2.to(b_rstd.dtype) / b_rstd, axis=1, keep_dims=True) + + b_v = tl.load(p_v, boundary_check=(0, 1), padding_option="zero") + b_w = b_w.to(b_k.dtype) + b_b = b_b.to(b_k.dtype) + b_dv = -b_w * b_dy.to(b_k.dtype) + b_dk = b_w * b_dy.to(b_k.dtype) + b_dw += tl.sum(2 * b_w * b_x * b_dy.to(b_k.dtype) + + (b_b - b_v.to(b_k.dtype) + b_k) * b_dy.to(b_k.dtype), axis=0).to(b_dw.dtype) + b_db += tl.sum(b_w * b_dy.to(b_k.dtype), axis=0).to(b_db.dtype) + b_dx = b_dx.to(b_k.dtype) + b_w * b_w * b_dy.to(b_k.dtype) + + b_h = tl.load(p_h, boundary_check=(0, 1), padding_option="zero") + b_q = (b_q * scale).to(b_q.dtype) + b_dkh = b_rstd * (V * b_dx - tl.sum(b_dx, axis=1, keep_dims=True) - + b_x * tl.sum(b_x * b_dx, axis=1, keep_dims=True)) / V + b_dkh -= b_rstd * b_rstd * b_drstd * b_x / V + b_dkh = tl.where((v_i < V)[None, :] * (o_i < T-i_t*BT)[:, None], b_dkh, 0.) + b_dk += tl.dot(b_dkh, b_h.to(b_dkh.dtype)).to(b_k.dtype) + + b_ds = tl.dot(b_do, tl.trans(b_v2)) + b_ds = tl.where(m_A, b_ds, 0) + b_ds = b_ds.to(b_k.dtype) + i_last = (BT-1) if (i_t*BT+BT) <= T else (T % BT-1) + mask = (o_i == i_last) + b_dk -= b_e_last * tl.dot(b_v2, tl.trans(b_dh).to(b_v2.dtype)) + b_dk -= tl.dot(tl.trans(b_ds), tl.trans(b_q) * b_e[:, None]) + b_de = mask * tl.sum(- b_dh * tl.trans(tl.dot(tl.trans(b_v2), b_k))).to(b_k.dtype) + b_de -= mask * tl.sum(b_dhb * tl.sum(b_v2, axis=0)).to(b_k.dtype) + b_de -= tl.sum(tl.dot(b_ds, b_k) * tl.trans(b_q).to(b_k.dtype), axis=1) + b_de -= tl.sum(b_ds, axis=1) + b_dh += tl.dot(b_q, b_do.to(b_q.dtype)) + tl.dot(tl.trans(b_k).to(b_dkh.dtype), b_dkh) + b_dhb += tl.sum(b_do + b_dkh, axis=0) + b_dh = tl.where((v_i < V)[None, :], b_dh, 0.) + b_dhb = tl.where((v_i < V), b_dhb, 0.) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_de, b_de.to(p_de.dtype.element_ty), boundary_check=(0,)) + tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0,)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0+i_nh*K*V, (K, V), (V, 1), (0, 0), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if USE_INITIAL_STATE_B: + p_dhb0 = tl.make_block_ptr(dhb0+i_nh*V, (V,), (1,), (0,), (BV,), (0,)) + tl.store(p_dhb0, b_dhb.to(p_dhb0.dtype.element_ty), boundary_check=(0,)) + + +def fused_chunk_ttt_linear_bwd_h( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + eps: float, + do: torch.Tensor, + BT: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + assert offsets is None, "bwd of varlen is not implemented yet." 
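# This wrapper recomputes the forward pass chunk by chunk and returns, besides dq, the per-chunk
# states h and the LayerNorm intermediates (v2, x, y, rstd), which the subsequent dh kernel
# consumes instead of recomputing them.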
+ if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + N, NT = B, triton.cdiv(T, BT) + BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V) + assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128." + + if head_first: + h = k.new_empty(B, H, NT, K, V) + r = v.new_empty(B, H, T, 1, dtype=torch.float32) + else: + h = k.new_empty(B, NT, H, K, V) + r = v.new_empty(B, T, H, 1, dtype=torch.float32) + v2 = torch.empty_like(v) + x = torch.empty_like(v) + y = torch.empty_like(v) + dq = torch.empty_like(q) + + grid = (N * H,) + fused_chunk_ttt_linear_bwd_kernel_h[grid]( + k=k, + v=v, + v2=v2, + x=x, + y=y, + r=r, + w=w, + b=b, + eta=eta, + h0=initial_state, + hb0=initial_state_bias, + h=h, + do=do, + dq=dq, + scale=scale, + eps=eps, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dq, h, v2, x, y, r + + +def fused_chunk_ttt_linear_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + v2: torch.Tensor, + x: torch.Tensor, + y: torch.Tensor, + r: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + h: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + dhbt: torch.Tensor, + BT: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + assert offsets is None, "bwd of varlen is not implemented yet." + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + N = B + BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V) + assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128." + + dh0 = torch.empty_like(initial_state, dtype=torch.float32) if initial_state is not None else None + dhb0 = torch.empty_like(initial_state_bias, dtype=torch.float32) if initial_state_bias is not None else None + dk = torch.empty_like(k) + dv = torch.empty_like(v) + de = torch.empty_like(eta) + dw = w.new_empty(B, H, V) + db = b.new_empty(B, H, V) + + grid = (N * H,) + fused_chunk_ttt_linear_bwd_kernel_dh[grid]( + q=q, + k=k, + v=v, + v2=v2, + x=x, + y=y, + r=r, + w=w, + b=b, + eta=eta, + h=h, + dht=dht, + dhbt=dhbt, + dh0=dh0, + dhb0=dhb0, + do=do, + dk=dk, + dv=dv, + de=de, + dw=dw, + db=db, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + dw = dw.sum(dim=0) + db = db.sum(dim=0) + return dk, dv, de, dw, db, dh0, dhb0 + + +def fused_chunk_ttt_linear_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + eps: float, + initial_state: torch.Tensor, + initial_state_bias: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True, + BT: int = 16 +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + N = B if offsets is None else len(offsets) - 1 + BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V) + assert max(BK, BV) <= 128, "current kernel does not support head dimension larger than 128." 
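# Layout note: the forward kernel below is launched with one program per (sequence, head) pair
# (grid = (N * H,)); each program loops over the NT chunks of its sequence internally and carries
# the running state (b_h, b_hb) across iterations instead of materializing it per chunk.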
+ o = torch.empty_like(v) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + final_state_bias = k.new_empty(N, H, 1, V, dtype=torch.float32) if output_final_state else None + + grid = (N * H,) + fused_chunk_ttt_linear_fwd_kernel[grid]( + q=q, + k=k, + v=v, + eta=eta, + w=w, + b=b, + o=o, + scale=scale, + eps=eps, + h0=initial_state, + hb0=initial_state_bias, + ht=final_state, + hbt=final_state_bias, + offsets=offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return o, final_state, final_state_bias + + +def fused_chunk_ttt_linear_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + eps: float, + do: torch.Tensor, + dht: torch.Tensor, + dhbt: torch.Tensor, + BT: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + assert offsets is None, "bwd of varlen is not implemented yet." + dq, h, v2, x, y, rstd = fused_chunk_ttt_linear_bwd_h( + q=q, + k=k, + v=v, + w=w, + b=b, + eta=eta, + scale=scale, + eps=eps, + do=do, + BT=BT, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + offsets=offsets, + head_first=head_first + ) + dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd_dh( + q=q, + k=k, + v=v, + v2=v2, + x=x, + y=y, + r=rstd, + w=w, + b=b, + eta=eta, + scale=scale, + h=h, + do=do, + dht=dht, + dhbt=dhbt, + BT=BT, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + offsets=offsets, + head_first=head_first + ) + return dq, dk, dv, de, dw, db, dh0, dhb0 + + +class FusedChunkTTTLinearFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, w, b, BT, eta, scale, eps, initial_state, + initial_state_bias, output_final_state, offsets, head_first): + o, final_state, final_state_bias = fused_chunk_ttt_linear_fwd( + q=q, + k=k, + v=v, + w=w, + b=b, + eta=eta, + scale=scale, + eps=eps, + BT=BT, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + output_final_state=output_final_state, + offsets=offsets, + head_first=head_first + ) + ctx.save_for_backward(q, k, v, eta, w, b, initial_state, initial_state_bias) + ctx.BT = BT + ctx.scale = scale + ctx.eps = eps + ctx.offsets = offsets + ctx.head_first = head_first + return o.to(q.dtype), final_state, final_state_bias + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht, dhbt): + q, k, v, eta, w, b, initial_state, initial_state_bias = ctx.saved_tensors + dq, dk, dv, de, dw, db, dh0, dhb0 = fused_chunk_ttt_linear_bwd( + q=q, + k=k, + v=v, + w=w, + b=b, + eta=eta, + scale=ctx.scale, + eps=ctx.eps, + do=do, + dht=dht, + dhbt=dhbt, + BT=ctx.BT, + initial_state=initial_state, + initial_state_bias=initial_state_bias, + offsets=ctx.offsets, + head_first=ctx.head_first + ) + return dq.to(q), dk.to(k), dv.to(v), dw.to(w), db.to(b), None, de.to(eta), None, None, dh0, dhb0, None, None, None + + +def norm_residual(x, weight, bias, eps, head_first): + # GroupNorm and Residual + if head_first: + B, H, T, D = x.shape + x = x.transpose(1, 2) + x += group_norm( + x.reshape(B, T, -1).clone(), + weight=weight.reshape(-1).clone(), + bias=bias.reshape(-1).clone(), + eps=eps, + num_groups=H, + ).reshape(x.shape) + x = x.transpose(1, 2) + else: + B, T, H, D = x.shape + x += group_norm( + x.reshape(B, T, -1).clone(), + 
weight=weight.reshape(-1).clone(), + bias=bias.reshape(-1).clone(), + eps=eps, + num_groups=H, + ).reshape(x.shape) + return x + + +def fused_chunk_ttt_linear( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float = None, + eps: float = 1e-6, + chunk_size: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True, +): + r""" + Args: + q (torch.Tensor): + queries of shape `(B, H, T, K)` + k (torch.Tensor): + keys of shape `(B, H, T, K)` + v (torch.Tensor): + values of shape `(B, H, T, V)` + w (torch.Tensor): + layer norm weight of shape `(H, V)` + b (torch.Tensor): + layer norm bias of shape `(H, V)` + eta (torch.Tensor): + Learning rate for hidden state, of shape `(B, H, T, 1)`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + chunk_size (int): + chunk size. Default: `16`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `(B, H, K, V)`. Default: `None`. + initial_state_bias (Optional[torch.Tensor]): + Initial state bias of shape `(B, H, 1, V)`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `(B, H, K, V)`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None`. + final_state_bias (torch.Tensor): + Final state bias of shape `[B, H, 1, V]` if `output_final_state=True` else `None`. + """ + assert q.dtype == k.dtype == v.dtype + assert k.shape[-1] == v.shape[-1], "DK must equal to DV." + if isinstance(eta, float): + eta = torch.full_like(q[:, :, :, :1], eta) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "Scale must be positive." 
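# Illustrative call (a sketch, not part of the diff; device, dtype, and sizes are assumptions):
#   B, H, T, D = 2, 4, 256, 64
#   q, k, v = (torch.randn(B, H, T, D, device='cuda', dtype=torch.bfloat16) for _ in range(3))
#   w, b = torch.ones(H, D, device='cuda'), torch.zeros(H, D, device='cuda')
#   o, ht, hbt = fused_chunk_ttt_linear(q, k, v, w, b, eta=0.01, chunk_size=16,
#                                       output_final_state=True)
#   # o: [B, H, T, D]; ht: [B, H, D, D]; hbt: [B, H, 1, D]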
+ o, final_state, final_state_bias = FusedChunkTTTLinearFunction.apply( + q, + k, + v, + w, + b, + chunk_size, + eta, + scale, + eps, + initial_state, + initial_state_bias, + output_final_state, + cu_seqlens, + head_first + ) + o = norm_residual(o, w, b, eps, head_first) + return o, final_state, final_state_bias diff --git a/fla/ops/ttt/naive.py b/fla/ops/ttt/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..0ad5dbba89989f6bfa7b13278f93e506f72a691a --- /dev/null +++ b/fla/ops/ttt/naive.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang, Yuqi Pan + +import torch +import torch.nn.functional as F + + +def ttt_linear( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float, + eps: float, + mini_batch_size: int, + initial_state: torch.Tensor, + initial_state_bias: torch.Tensor, + output_final_state: bool +): + B, H, T, D = q.shape + BT = mini_batch_size + NT = T // BT + # [NT, B, H, mini_batch_size, D] + _q = q.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4) + _k = k.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4) + _v = v.reshape(B, H, NT, BT, D).permute(2, 0, 1, 3, 4) + # [NT, B, H, BT, 1] + _eta = eta.reshape(B, H, NT, BT, 1).permute(2, 0, 1, 3, 4) + # [H, 1, D] + w = w.reshape(H, 1, D).to(torch.float32) + b = b.reshape(H, 1, D).to(torch.float32) + + h = torch.zeros((B, H, D, D), device=v.device, dtype=torch.float32) if initial_state is None else initial_state + hb = torch.zeros((B, H, 1, D), device=v.device, dtype=torch.float32) if initial_state_bias is None else initial_state_bias + q *= scale + # [NT, B, H, BT, D] + o = torch.empty_like(_v) + + for i in range(NT): + q_i, k_i, v_i, eta_i = [x[i] for x in [_q, _k, _v, _eta]] + kh = k_i @ h + hb + reconstruction_target = v_i - k_i + + mean = kh.mean(-1, True) + var = kh.var(-1, unbiased=False, keepdim=True).to(torch.float32) + rstd = torch.sqrt(var + eps).to(torch.float32) + kh_hat = (kh - mean) / rstd + + g = w * kh_hat + b - reconstruction_target + g *= w + v_new = (D * g - g.sum(-1, True) - kh_hat * (g * kh_hat).sum(-1, True)) / (rstd * D) + + Attn = torch.tril(q_i @ k_i.transpose(-2, -1)) + o_i = q_i @ h - (eta_i * Attn) @ v_new + hb - torch.tril(eta_i.expand_as(Attn)) @ v_new + h = h - (eta_i[:, :, -1, :, None] * k_i).transpose(-1, -2) @ v_new + hb = hb - torch.sum(eta_i[:, :, -1, :, None] * v_new, dim=-2, keepdim=True) + # layer norm with residuals + + mean = o_i.mean(dim=-1, keepdim=True) + var = o_i.var(dim=-1, unbiased=False, keepdim=True).to(torch.float32) + rstd = torch.sqrt(var + eps).to(torch.float32) + o[i] = o_i + (o_i - mean) / rstd * w + b + + # [B, H, T, D] + o = o.permute(1, 2, 0, 3, 4).reshape(B, H, T, D) + h = h if output_final_state else None + hb = hb if output_final_state else None + return o, h, hb + + +def chunk_ttt_linear_ref( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + w: torch.Tensor, + b: torch.Tensor, + eta: torch.Tensor, + scale: float = None, + eps: float = 1e-6, + mini_batch_size: int = 16, + initial_state: torch.Tensor = None, + initial_state_bias: torch.Tensor = None, + output_final_state: bool = False, + head_first: bool = True, +): + assert q.dtype == k.dtype == v.dtype + assert k.shape[-1] == v.shape[-1], "The key and value dimension must be the same." 
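# chunk_ttt_linear_ref serves as a pure-PyTorch reference for the Triton kernels; a typical
# consistency check (a sketch only, tolerances and imports are assumptions) compares it against
# the fused path on the same inputs:
#   o_ref, h_ref, hb_ref = chunk_ttt_linear_ref(q, k, v, w, b, eta, output_final_state=True)
#   o_tri, h_tri, hb_tri = fused_chunk_ttt_linear(q, k, v, w, b, eta, output_final_state=True)
#   torch.testing.assert_close(o_ref, o_tri, rtol=1e-2, atol=1e-2)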
+ if isinstance(eta, float): + eta = torch.full_like(q[:, :, :, :1], eta) + if scale is None: + scale = k.shape[-1] ** -0.5 + if not head_first: + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + eta = eta.transpose(1, 2) + T = q.shape[-2] + padded = (mini_batch_size - (T % mini_batch_size)) % mini_batch_size + if padded > 0: + q = F.pad(q, (0, 0, 0, padded)) + k = F.pad(k, (0, 0, 0, padded)) + v = F.pad(v, (0, 0, 0, padded)) + eta = F.pad(eta, (0, 0, 0, padded)) + eta[:, :, -1, :] = eta[:, :, -(padded+1), :] + assert q.shape[-2] % mini_batch_size == 0, "Sequence length should be a multiple of mini_batch_size." + q, k, v, eta, w, b = map(lambda x: x.to(torch.float32), [q, k, v, eta, w, b]) + o, final_state, final_state_bias = ttt_linear( + q, + k, + v, + w, + b, + eta, + scale, + eps, + mini_batch_size, + initial_state, + initial_state_bias, + output_final_state, + ) + o = o[:, :, :T, :].contiguous() + if not head_first: + o = o.transpose(1, 2) + return o, final_state, final_state_bias diff --git a/fla/ops/utils/__init__.py b/fla/ops/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4b0ff5fcf03073efdcf657043ecdd482c8eec1 --- /dev/null +++ b/fla/ops/utils/__init__.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- + +from .asm import fp32_to_tf32_asm +from .cumsum import ( + chunk_global_cumsum, + chunk_global_cumsum_scalar, + chunk_global_cumsum_scalar_kernel, + chunk_global_cumsum_vector, + chunk_global_cumsum_vector_kernel, + chunk_local_cumsum, + chunk_local_cumsum_scalar, + chunk_local_cumsum_scalar_kernel, + chunk_local_cumsum_vector, + chunk_local_cumsum_vector_kernel +) +from .logcumsumexp import logcumsumexp_fwd_kernel +from .logsumexp import logsumexp_fwd, logsumexp_fwd_kernel +from .matmul import addmm, matmul, matmul_kernel +from .pooling import mean_pooling +from .softmax import softmax_bwd, softmax_bwd_kernel, softmax_fwd, softmax_fwd_kernel + +__all__ = [ + 'chunk_global_cumsum', + 'chunk_global_cumsum_scalar', + 'chunk_global_cumsum_scalar_kernel', + 'chunk_global_cumsum_vector', + 'chunk_global_cumsum_vector_kernel', + 'chunk_local_cumsum', + 'chunk_local_cumsum_scalar', + 'chunk_local_cumsum_scalar_kernel', + 'chunk_local_cumsum_vector', + 'chunk_local_cumsum_vector_kernel', + 'logcumsumexp_fwd_kernel', + 'logsumexp_fwd', + 'logsumexp_fwd_kernel', + 'addmm', + 'matmul', + 'matmul_kernel', + 'mean_pooling', + 'softmax_bwd', + 'softmax_bwd_kernel', + 'softmax_fwd', + 'softmax_fwd_kernel', + 'fp32_to_tf32_asm', +] diff --git a/fla/ops/utils/__pycache__/__init__.cpython-311.pyc b/fla/ops/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c7f2e6cff6acbcbc972bc370c6c3f517b883805 Binary files /dev/null and b/fla/ops/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/asm.cpython-311.pyc b/fla/ops/utils/__pycache__/asm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c37e0579b7cfa2f16b94cf378db0b65e738aa96 Binary files /dev/null and b/fla/ops/utils/__pycache__/asm.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/cumsum.cpython-311.pyc b/fla/ops/utils/__pycache__/cumsum.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be1454685a633cdbd4946d205a49a02f77e8f53c Binary files /dev/null and b/fla/ops/utils/__pycache__/cumsum.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/logcumsumexp.cpython-311.pyc 
b/fla/ops/utils/__pycache__/logcumsumexp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b4f495f028893e451c4b0b278831c80e8adfd08 Binary files /dev/null and b/fla/ops/utils/__pycache__/logcumsumexp.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/logsumexp.cpython-311.pyc b/fla/ops/utils/__pycache__/logsumexp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..128cd8977951eb95d0cf536d2cde049022021210 Binary files /dev/null and b/fla/ops/utils/__pycache__/logsumexp.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/matmul.cpython-311.pyc b/fla/ops/utils/__pycache__/matmul.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..033e4ec4084309afc49d7098e48546d293e58ca0 Binary files /dev/null and b/fla/ops/utils/__pycache__/matmul.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/op.cpython-311.pyc b/fla/ops/utils/__pycache__/op.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cf5dd3c7663a92a1c1fd09a269a468014f227e0 Binary files /dev/null and b/fla/ops/utils/__pycache__/op.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/pooling.cpython-311.pyc b/fla/ops/utils/__pycache__/pooling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f54e9fe07b6b8cac25240135b1ccb8b03d71e0a Binary files /dev/null and b/fla/ops/utils/__pycache__/pooling.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/softmax.cpython-311.pyc b/fla/ops/utils/__pycache__/softmax.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6beaaaf1bfe4c7669acb6366dc49cc2555b7fbd8 Binary files /dev/null and b/fla/ops/utils/__pycache__/softmax.cpython-311.pyc differ diff --git a/fla/ops/utils/__pycache__/solve_tril.cpython-311.pyc b/fla/ops/utils/__pycache__/solve_tril.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b516e91cda9d033e4b2618341922944df641b85d Binary files /dev/null and b/fla/ops/utils/__pycache__/solve_tril.cpython-311.pyc differ diff --git a/fla/ops/utils/asm.py b/fla/ops/utils/asm.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a96bad2cecf24733832b6817f8d4b855685f05 --- /dev/null +++ b/fla/ops/utils/asm.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +from fla.utils import device_platform + + +def fp32_to_tf32_asm() -> str: + """ + Get the assembly code for converting FP32 to TF32. 
+ """ + ASM_DICT = { + 'nvidia': 'cvt.rna.tf32.f32 $0, $1;' + } + if device_platform in ASM_DICT: + return ASM_DICT[device_platform] + else: + # return empty string if the device is not supported + return "" diff --git a/fla/ops/utils/cumsum.py b/fla/ops/utils/cumsum.py new file mode 100644 index 0000000000000000000000000000000000000000..5a5f3e90d39566507f01660040bfdb1986d25adb --- /dev/null +++ b/fla/ops/utils/cumsum.py @@ -0,0 +1,400 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import check_shared_mem, input_guard + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_scalar_kernel( + s, + o, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr, + REVERSE: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BS': BS}, num_warps=num_warps) + for BS in BS_LIST + for num_warps in [2, 4, 8] + ], + key=['S', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_local_cumsum_vector_kernel( + s, + o, + offsets, + indices, + T, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr, + REVERSE: tl.constexpr +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) 
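# The chunk-local cumulative sum is computed as a matmul with this triangular mask:
# b_o = m_s @ b_s below yields an inclusive cumsum along the chunk dimension (reversed when
# REVERSE is set), i.e. the blockwise equivalent of torch.cumsum over each chunk.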
+ + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BT': 16}, num_warps=2), + triton.Config({'BT': 32}, num_warps=4), + triton.Config({'BT': 32}, num_warps=2), + triton.Config({'BT': 64}, num_warps=8), + triton.Config({'BT': 64}, num_warps=4), + ], + key=[] +) +@triton.jit(do_not_specialize=['T']) +def chunk_global_cumsum_scalar_kernel( + s, + o, + offsets, + T, + H: tl.constexpr, + BT: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr, + REVERSE: tl.constexpr +): + i_bh = tl.program_id(0) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + b_z = tl.zeros([], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT-1-i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + b_ss = tl.sum(b_s, 0) + if REVERSE: + b_o = -b_o + b_ss + b_s + b_o += b_z + if i_c >= 0: + b_z += b_ss + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps) + for BT in [16, 32, 64] + for num_warps in [2, 4, 8] + ], + key=['S'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_global_cumsum_vector_kernel( + s, + z, + offsets, + T, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr, + REVERSE: tl.constexpr +): + i_s, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) 
+ + b_z = tl.zeros([BS], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT-1-i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1)) + if i_c >= 0: + b_z += tl.sum(b_s, 0) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + head_first: bool = True, + output_dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + if offsets is not None: + B = len(offsets) - 1 + assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, + g, + offsets, + indices, + T=T, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + head_first: bool = True, + output_dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) + # keep cummulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, + g, + offsets, + indices, + T=T, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse + ) + return g + + +@input_guard +def chunk_global_cumsum_scalar( + s: torch.Tensor, + dtype: Optional[torch.dtype] = None, + reverse: bool = False, + offsets: Optional[torch.Tensor] = None, + head_first: bool = True, + output_dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + dtype = dtype or s.dtype + if head_first: + B, H, T = s.shape + else: + B, T, H = s.shape + if offsets is not None: + B = len(offsets) - 1 + grid = (B * H,) + z = torch.empty_like(s, dtype=output_dtype or dtype) + chunk_global_cumsum_scalar_kernel[grid]( + s, + z, + offsets, + T=T, + H=H, + HEAD_FIRST=head_first, + REVERSE=reverse + ) + return z + + +@input_guard +def chunk_global_cumsum_vector( + s: torch.Tensor, + dtype: Optional[torch.dtype] = None, + reverse: bool = False, + offsets: Optional[torch.Tensor] = None, + head_first: bool = True, + output_dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + dtype = dtype or s.dtype + if head_first: + B, H, T, 
S = s.shape + else: + B, T, H, S = s.shape + BS = min(32, triton.next_power_of_2(S)) + if offsets is not None: + B = len(offsets) - 1 + grid = (triton.cdiv(S, BS), B * H) + z = torch.empty_like(s, dtype=output_dtype or dtype) + chunk_global_cumsum_vector_kernel[grid]( + s, + z, + offsets, + T=T, + H=H, + S=S, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse + ) + return z + + +@input_guard +def chunk_global_cumsum( + s: torch.Tensor, + dtype: Optional[torch.dtype] = None, + reverse: bool = False, + offsets: Optional[torch.Tensor] = None, + head_first: bool = True, + output_dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + if offsets is not None: + assert s.shape[0] == 1, "Only batch size 1 is supported when offsets are provided" + if len(s.shape) == 3: + return chunk_global_cumsum_scalar(s, dtype, reverse, offsets, head_first, output_dtype) + elif len(s.shape) == 4: + return chunk_global_cumsum_vector(s, dtype, reverse, offsets, head_first, output_dtype) + else: + raise ValueError(f"Unsupported input shape {s.shape}. " + f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` " + f"or [B, T, H]/[B, T, H, D] otherwise") + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + head_first: bool = True, + output_dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + if offsets is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype) + else: + raise ValueError(f"Unsupported input shape {g.shape}. " + f"which should be (B, H, T, dim) if `head_first=True` " + f"or (batch_size, num_heads, seq_len) otherwise") diff --git a/fla/ops/utils/logcumsumexp.py b/fla/ops/utils/logcumsumexp.py new file mode 100644 index 0000000000000000000000000000000000000000..7476d3220599aa78e8f8ae5b10d0d15297cc47b4 --- /dev/null +++ b/fla/ops/utils/logcumsumexp.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, log + + +@triton.autotune( + configs=[ + triton.Config({'BT': BT}, num_warps=num_warps) + for BT in [16, 32, 64] + for num_warps in [2, 4, 8] + ], + key=['S'] +) +@triton.jit(do_not_specialize=['T']) +def logcumsumexp_fwd_kernel( + s, + z, + T, + S: tl.constexpr, + BT: tl.constexpr +): + i_bh = tl.program_id(0) + o_i = tl.arange(0, BT) + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.) 
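# The loop below is a streaming log-sum-exp: b_mp holds the running maximum and b_zp the partial
# sums, rescaled by exp(b_mp - b_mc) whenever the maximum grows, so the cumulative result stays
# numerically stable across successive blocks of BT rows.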
+ + b_mp = tl.full([S,], float('-inf'), dtype=tl.float32) + b_zp = tl.zeros([S,], dtype=tl.float32) + for i_t in range(tl.cdiv(T, BT)): + p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, 0), (BT, S), (1, 0)) + + # [BT, S] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + # [S,] + b_mc = tl.max(b_s, 0) + b_mc = tl.maximum(b_mp, b_mc) + b_zp = b_zp * exp(b_mp - b_mc) + # [BT, S] + b_s = exp(b_s - b_mc) + b_z = tl.dot(m_s, b_s, allow_tf32=False) + b_zp + # [S,] + b_zc = tl.max(b_z, 0) + b_mp = b_mc + b_zp = b_zc + # [BT, BS] + # small eps to prevent underflows + b_z = log(tl.where(b_z != 0, b_z, 1e-20)) + b_mc + tl.store(p_z, b_z.to(p_z.dtype.element_ty), boundary_check=(0, 1)) diff --git a/fla/ops/utils/logsumexp.py b/fla/ops/utils/logsumexp.py new file mode 100644 index 0000000000000000000000000000000000000000..b647012b68c05ee59783d3d3615961962895a185 --- /dev/null +++ b/fla/ops/utils/logsumexp.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, log + + +@triton.heuristics({ + 'HAS_SCALE': lambda args: args['scale'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4, 8, 16, 32] + ], + key=['D'] +) +@triton.jit +def logsumexp_fwd_kernel( + x, + z, + scale, + D: tl.constexpr, + B: tl.constexpr, + HAS_SCALE: tl.constexpr +): + i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64) + o_d = i_d * B + tl.arange(0, B) + m_d = o_d < D + + b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf')) + if HAS_SCALE: + b_x = b_x * scale + b_m = tl.max(b_x, 0) + b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m + tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z) + + +def logsumexp_fwd( + x, + scale: Optional[float] = None, + dtype: Optional[torch.dtype] = None +): + r""" + Compute the logsumexp of the input tensor over the last dimension. + + Args: + x (Tensor): + The input tensor of any shape. + scale (Optional[float]): + The scale applied to the input tensor. Default: `None`. + dtype (Optional[torch.dtype]): + The data type of the output tensor. Default: `None`. + Returns: + Tensor: The logsumexp of the input tensor. 
+ """ + + shape = x.shape + x = x.view(-1, shape[-1]) + N, D = x.shape + B = min(triton.next_power_of_2(D), 64 * 1024) + ND = triton.cdiv(D, B) + + z = x.new_empty(N, ND, dtype=torch.float) + logsumexp_fwd_kernel[(N, ND)]( + x=x, + z=z, + scale=scale, + D=D, + B=B + ) + z = z.logsumexp(-1).view(*shape[:-1]) + if dtype is not None and dtype != torch.float: + z = z.to(dtype) + return z diff --git a/fla/ops/utils/matmul.py b/fla/ops/utils/matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..994bcecd237c721eb1c2d8511b0f15d5d0aa804d --- /dev/null +++ b/fla/ops/utils/matmul.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# code adapted from +# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BM`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.heuristics({ + 'HAS_ALPHA': lambda args: args['alpha'] is not None, + 'HAS_BETA': lambda args: args['beta'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BM': 128, 'BK': 64, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8), + triton.Config({'BM': 64, 'BK': 32, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4), + triton.Config({'BM': 128, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4), + triton.Config({'BM': 128, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4), + triton.Config({'BM': 64, 'BK': 32, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4), + triton.Config({'BM': 128, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4), + triton.Config({'BM': 64, 'BK': 32, 'BN': 32, 'G': 4}, num_stages=5, num_warps=2), + triton.Config({'BM': 32, 'BK': 32, 'BN': 64, 'G': 4}, num_stages=5, num_warps=2), + # Good config for fp8 inputs. + # triton.Config({'BM': 128, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=3, num_warps=8), + # triton.Config({'BM': 256, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=3, num_warps=8), + # triton.Config({'BM': 256, 'BK': 128, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 64, 'BK': 128, 'BN': 256, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 128, 'BK': 128, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 128, 'BK': 64, 'BN': 64, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 64, 'BK': 64, 'BN': 128, 'G': 4}, num_stages=4, num_warps=4), + # triton.Config({'BM': 128, 'BK': 64, 'BN': 32, 'G': 4}, num_stages=4, num_warps=4) + ], + key=['M', 'N', 'K'] +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a, + b, + c, + input, + alpha, + beta, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `s_am` is how much to increase `a` + # by to get the element one row down (A has M rows). 
+ stride_ab, stride_am, stride_ak, # a: batch, M, K + stride_bk, stride_bn, # b: K, N + stride_cb, stride_cm, stride_cn, # c: batch, M, N + # Meta-parameters + BM: tl.constexpr, + BK: tl.constexpr, + BN: tl.constexpr, + G: tl.constexpr, + ACTIVATION: tl.constexpr, + HAS_INPUT: tl.constexpr, + HAS_ALPHA: tl.constexpr, + HAS_BETA: tl.constexpr, + ALLOW_TF32: tl.constexpr, + X_DIM: tl.constexpr = 1, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + i_b, i_m, i_n = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + NM, NN = tl.num_programs(1), tl.num_programs(2) + i_m, i_n = tl.swizzle2d(i_m, i_n, NM, NN, G) + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `p_a` is a block of [BM, BK] pointers + # `p_b` is a block of [BK, BN] pointers + # See above `Pointer Arithmetic` section for details + a_batch_ptr = a + i_b * stride_ab + o_am = (i_m * BM + tl.arange(0, BM)) % M + o_bn = (i_n * BN + tl.arange(0, BN)) % N + o_k = tl.arange(0, BK) + + p_a = a_batch_ptr + (o_am[:, None] * stride_am + o_k[None, :] * stride_ak) + p_b = b + (o_k[:, None] * stride_bk + o_bn[None, :] * stride_bn) + + b_acc = tl.zeros((BM, BN), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BK)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + b_a = tl.load(p_a, mask=o_k[None, :] < K - k * BK, other=0.0) + b_b = tl.load(p_b, mask=o_k[:, None] < K - k * BK, other=0.0) + # We accumulate along the K dimension. + b_acc = tl.dot(b_a, b_b, acc=b_acc, allow_tf32=ALLOW_TF32) + # Advance the ptrs to the next K block. + p_a += BK * stride_ak + p_b += BK * stride_bk + + o_cm = i_m * BM + tl.arange(0, BM) + o_cn = i_n * BN + tl.arange(0, BN) + mask = (o_cm[:, None] < M) & (o_cn[None, :] < N) + + b_c = b_acc + # You can fuse arbitrary activation functions here + # while the b_acc is still in FP32! + if ACTIVATION == "leaky_relu": + b_c = leaky_relu(b_c) + elif ACTIVATION == "relu": + b_c = relu(b_c) + elif ACTIVATION == "sigmoid": + b_c = sigmoid(b_c) + elif ACTIVATION == "tanh": + b_c = tanh(b_c) + + if HAS_ALPHA: + b_c *= tl.load(alpha) + + if HAS_INPUT: + p_i = input + (stride_cm * o_cm[:, None] if X_DIM == 2 else 0) + stride_cn * o_cn[None, :] + mask_p = (o_cn[None, :] < N) if X_DIM == 1 else mask + b_i = tl.load(p_i, mask=mask_p, other=0.0).to(tl.float32) + if HAS_BETA: + b_i *= tl.load(beta) + b_c += b_i + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + c_batch_ptr = c + i_b * stride_cb + p_c = c_batch_ptr + stride_cm * o_cm[:, None] + stride_cn * o_cn[None, :] + tl.store(p_c, b_c.to(c.dtype.element_ty), mask=mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. 
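# A quick way to exercise the fused-activation path (a sketch, not part of the diff; the CUDA
# device, sizes, and tolerances are assumptions) is to compare the `matmul` wrapper defined
# below against eager PyTorch:
#   a = torch.randn(256, 128, device='cuda', dtype=torch.float16)
#   b = torch.randn(128, 64, device='cuda', dtype=torch.float16)
#   torch.testing.assert_close(matmul(a, b, activation='relu'), torch.relu(a @ b),
#                              rtol=1e-2, atol=1e-2)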
+@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + + +@triton.jit +def sigmoid(x): + # σ(x) = 1 / (1 + exp(-x)) + return 1.0 / (1.0 + exp(-x)) + + +@triton.jit +def tanh(x): + # tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) + # 2 * sigmoid(2x) - 1 + return (exp(x) - exp(-x)) / (exp(x) + exp(-x)) + + +@triton.jit +def relu(x): + # ReLU(x) = max(0, x) + return tl.maximum(x, 0.0) + + +@input_guard +def matmul(a, b, activation=''): + assert a.dim() in [2, 3], "a must be 2D or 3D" + assert b.dim() == 2, "b must be 2D" + assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}" + + if a.dim() == 2: + a_dim = 2 + a = a.unsqueeze(0).contiguous() # (1, M, K) + else: + a_dim = 3 + allow_tf32 = False if a.dtype == torch.float32 else True + + B, M, K = a.shape[0], a.shape[1], a.shape[2] + K_b, N = b.shape + assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}" + c = a.new_empty(B, M, N) + + def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN'])) + matmul_kernel[grid]( + a, b, c, None, None, None, + M, N, K, + a.stride(0), a.stride(1), a.stride(2), # stride_ab, stride_am, stride_ak + b.stride(0), b.stride(1), # stride_bk, stride_bn (b.dim() == 2) + c.stride(0), c.stride(1), c.stride(2), # stride_cb, stride_cm, stride_cn + ACTIVATION=activation, + ALLOW_TF32=allow_tf32, + HAS_INPUT=False, + ) + return c.squeeze(0) if a_dim == 2 else c + + +@input_guard +def addmm( + x: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + alpha: Optional[float] = None, + beta: Optional[float] = None, +) -> torch.Tensor: + assert a.dim() in [2, 3], "a must be 2D or 3D" + assert b.dim() == 2, "b must be 2D" + assert a.shape[-1] == b.shape[0], f"Incompatible dimensions: A {a.shape}, B {b.shape}" + + if a.dim() == 2: + a_dim = 2 + a = a.unsqueeze(0).contiguous() # (1, M, K) + else: + a_dim = 3 + allow_tf32 = False if a.dtype == torch.float32 else True + + B, M, K = a.shape[0], a.shape[1], a.shape[2] + K_b, N = b.shape + assert K == K_b, f"Incompatible K dimension: A {K} vs B {K_b}" + c = a.new_empty(B, M, N) + + def grid(meta): return (B, triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN'])) + matmul_kernel[grid]( + a, b, c, x, alpha, beta, + M, N, K, + a.stride(0), a.stride(1), a.stride(2), # stride_ab, stride_am, stride_ak + b.stride(0), b.stride(1), # stride_bk, stride_bn (b.dim() == 2) + c.stride(0), c.stride(1), c.stride(2), # stride_cb, stride_cm, stride_cn + ACTIVATION=None, + ALLOW_TF32=allow_tf32, + HAS_INPUT=True, + X_DIM=x.dim(), + ) + return c.squeeze(0) if a_dim == 2 else c diff --git a/fla/ops/utils/op.py b/fla/ops/utils/op.py new file mode 100644 index 0000000000000000000000000000000000000000..f0fe269ed8756b6a7b3ea396dffdfdd56b924ea9 --- /dev/null +++ b/fla/ops/utils/op.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +import os + +import triton +import triton.language as tl +import triton.language.extra.libdevice as tldevice + +from fla.utils import is_gather_supported + +if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': + div = tldevice.fast_dividef + exp = tldevice.fast_expf + log = tldevice.fast_logf + log2 = tldevice.fast_log2f +else: + @triton.jit + def div_normal(x, y): + return x / y + div = div_normal + exp = tl.exp + log = tl.log + log2 = tl.log2 + + +@triton.jit +def safe_exp(x): + return exp(tl.where(x <= 0, x, float('-inf'))) + + +if not is_gather_supported: + def gather(*args, **kwargs): + pass +else: + gather = tl.gather diff --git a/fla/ops/utils/pooling.py 
b/fla/ops/utils/pooling.py new file mode 100644 index 0000000000000000000000000000000000000000..0dd9059b4abd0a87fb65e25c01fd5897452f77e0 --- /dev/null +++ b/fla/ops/utils/pooling.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_indices +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [16, 32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def mean_pooling_fwd_kernel( + x, + o, + offsets, + indices, + T: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NT: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,)) + # [BT, BD] + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + # [BD] + b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BD': BD}, num_warps=num_warps) + for BD in [16, 32, 64, 128] + for num_warps in [1, 2, 4, 8] + ], + key=['BT'] +) +@triton.jit(do_not_specialize=['T']) +def mean_pooling_bwd_kernel( + do, + dx, + offsets, + indices, + T: tl.constexpr, + H: tl.constexpr, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, + NT: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0)) + p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,)) + # [BD] + b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32) + # [BT, BD] + b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None] + tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1)) + + +def mean_pooling_fwd( + x: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None +) -> torch.Tensor: + B, T, H, D = x.shape + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = x.new_empty(B, 
NT, H, D) + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H) + mean_pooling_fwd_kernel[grid]( + x, + o, + offsets, + indices, + T=T, + H=H, + D=D, + BT=BT, + NT=NT, + ) + return o + + +def mean_pooling_bwd( + do: torch.Tensor, + batch_size: int, + seq_len: int, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None +) -> torch.Tensor: + B, T, H, D = batch_size, seq_len, *do.shape[-2:] + BT = chunk_size + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dx = do.new_empty(B, T, H, D) + def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H) + mean_pooling_bwd_kernel[grid]( + do, + dx, + offsets, + indices, + T=T, + H=H, + D=D, + BT=BT, + NT=NT + ) + return dx + + +class MeanPoolingFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + x: torch.Tensor, + chunk_size: int, + offsets: Optional[torch.LongTensor] = None + ) -> torch.Tensor: + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + o = mean_pooling_fwd(x, chunk_size, offsets, indices) + ctx.batch_size = x.shape[0] + ctx.seq_len = x.shape[1] + ctx.chunk_size = chunk_size + ctx.offsets = offsets + ctx.indices = indices + return o + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward( + ctx, do + ) -> Tuple[torch.Tensor, None, None]: + batch_size = ctx.batch_size + seq_len = ctx.seq_len + chunk_size = ctx.chunk_size + offsets = ctx.offsets + indices = ctx.indices + dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, offsets, indices) + return dx, None, None + + +def mean_pooling( + x: torch.Tensor, + chunk_size: int, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + if head_first: + x = x.transpose(1, 2) + if cu_seqlens is not None: + if x.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens) + if head_first: + o = o.transpose(1, 2) + return o diff --git a/fla/ops/utils/softmax.py b/fla/ops/utils/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..12c37c7a57061c8d8dfd2ab6a31b2dc33547607f --- /dev/null +++ b/fla/ops/utils/softmax.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32) + ], + key=['D'] +) +@triton.jit +def softmax_fwd_kernel( + x, + p, + D: tl.constexpr, + B: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < D + + b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf')) + b_m = tl.max(b_x, 0) + b_x = exp(b_x - b_m) + b_p = b_x / tl.sum(b_x, 0) + + tl.store(p + i_n * D + o_d, b_p.to(p.dtype.element_ty), mask=m_d) + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32) + ], + key=['D'] +) +@triton.jit +def softmax_bwd_kernel( + p, + dp, + ds, + D: tl.constexpr, + B: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < D + + b_p = tl.load(p + i_n * D + o_d, mask=m_d, other=0.) + b_dp = tl.load(dp + i_n * D + o_d, mask=m_d, other=0.) + b_pp = tl.sum(b_p * b_dp, 0) + b_ds = b_p * b_dp - b_p * b_pp + tl.store(ds + i_n * D + o_d, b_ds.to(ds.dtype.element_ty), mask=m_d) + + +def softmax_fwd( + x: torch.Tensor, + dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + shape = x.shape + x = x.view(-1, x.shape[-1]) + + N, D = x.shape + B = triton.next_power_of_2(D) + + p = torch.empty_like(x, dtype=dtype) + softmax_fwd_kernel[(N,)]( + x=x, + p=p, + D=D, + B=B + ) + return p.view(*shape) + + +def softmax_bwd( + p: torch.Tensor, + dp: torch.Tensor, + dtype: Optional[torch.dtype] = torch.float +) -> torch.Tensor: + shape = p.shape + p = p.view(-1, p.shape[-1]) + ds = torch.empty_like(p, dtype=dtype) + + N, D = p.shape + B = triton.next_power_of_2(D) + softmax_bwd_kernel[(N,)]( + p=p, + dp=dp, + ds=ds, + D=D, + B=B + ) + return ds.view(*shape) diff --git a/fla/ops/utils/solve_tril.py b/fla/ops/utils/solve_tril.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c2b66833c4d1479ec8c25ae82a85e69e96650a --- /dev/null +++ b/fla/ops/utils/solve_tril.py @@ -0,0 +1,321 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_indices +from fla.utils import input_guard + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=['BT'], +) +@triton.jit(do_not_specialize=['T']) +def solve_tril_16x16_kernel( + A, + Ad, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_OFFSETS: tl.constexpr, 
+ HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + A = A + i_bh * T * BT + Ad = Ad + i_bh * T * 16 + stride_16 = 16 + stride_BT = BT + else: + A = A + (bos*H + i_h) * BT + Ad = Ad + (bos*H + i_h) * 16 + stride_16 = H*16 + stride_BT = H*BT + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr(A, (T, BT), (stride_BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T-i_t*16)): + b_a = -tl.load(A + (i_t * 16 + i) * stride_BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=['H', 'BT', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ad, + Ai, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + A += (i_bh * T * 32) + Ad += (i_bh * T * 16) + Ai += (i_bh * T * 32) + stride_16 = 16 + stride_32 = 32 + else: + A += (bos*H + i_h) * 32 + Ad += (bos*H + i_h) * 16 + Ai += (bos*H + i_h) * 32 + stride_16 = 16 * H + stride_32 = 32 * H + + p_A_21 = tl.make_block_ptr(A, (T, 32), (stride_32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (stride_32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)) + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee') + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), 
boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=['H', 'BT', 'HEAD_FIRST', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ad, + Ai, + offsets, + indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + A += i_bh * T * 64 + Ad += i_bh * T * 16 + Ai += i_bh * T * 64 + stride_16 = 16 + stride_64 = 64 + else: + A += (bos*H + i_h) * 64 + Ad += (bos*H + i_h) * 16 + Ai += (bos*H + i_h) * 64 + stride_16 = 16 * H + stride_64 = 64 * H + + p_A_21 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, 64), (stride_64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (stride_16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)) + Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)) + + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'), Ai_11, input_precision='ieee') + Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'), Ai_22, input_precision='ieee') + Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'), Ai_33, input_precision='ieee') + + Ai_31 = -tl.dot( + Ai_33, + tl.dot(A_31, Ai_11, input_precision='ieee') + + tl.dot(A_32, Ai_21, input_precision='ieee'), + input_precision='ieee' + ) + Ai_42 = -tl.dot( + Ai_44, + tl.dot(A_42, Ai_22, input_precision='ieee') + + tl.dot(A_43, Ai_32, input_precision='ieee'), + input_precision='ieee' + ) + Ai_41 = -tl.dot( + Ai_44, + tl.dot(A_41, Ai_11, input_precision='ieee') + + tl.dot(A_42, Ai_21, input_precision='ieee') + + tl.dot(A_43, Ai_31, input_precision='ieee'), + input_precision='ieee' + ) + + p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 16, 16), 
(16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (stride_64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@input_guard +def solve_tril( + A: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float +) -> torch.Tensor: + """ + Compute the inverse of (I + A). + A must be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, H, T, K] if head_first else [B, T, H, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + head_first (bool): + If False, the input/output tensor is in the shape of [B, T, H, K]. + If True, the input/output tensor is in the shape of [B, H, T, K]. + Default: False + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + assert A.dtype == torch.float, "A should be float32."
+ + if head_first: + B, H, T, BT = A.shape + Ad = torch.empty(B, H, T, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) + else: + B, T, H, BT = A.shape + Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) + + indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None + NT = len(indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + offsets=cu_seqlens, + indices=indices, + T=T, + H=H, + BT=BT, + HEAD_FIRST=head_first, + ) + if BT == 16: + return Ad + + if head_first: + Ai = torch.zeros(B, H, T, BT, device=A.device, dtype=output_dtype) + else: + Ai = torch.zeros(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + offsets=cu_seqlens, + indices=indices, + T=T, + H=H, + BT=BT, + HEAD_FIRST=head_first, + USE_OFFSETS=cu_seqlens is not None + ) + return Ai diff --git a/fla/ops/utils/testing.py b/fla/ops/utils/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4fb01202e6bfbea3351defe8a424ca648ce7d1 --- /dev/null +++ b/fla/ops/utils/testing.py @@ -0,0 +1,26 @@ +import os + +compiled_mode = os.getenv("COMPILER_MODE") == "1" +ci_env = os.getenv("CI_ENV") == "1" + + +def get_abs_err(x, y): + return (x.detach()-y.detach()).flatten().abs().max().item() + + +def get_err_ratio(x, y): + err = (x-y).flatten().square().mean().sqrt().item() + base = (x).flatten().square().mean().sqrt().item() + return err / (base + 1e-15) + + +def assert_close(prefix, ref, tri, ratio, warning=False): + msg = f"{prefix} diff: {get_abs_err(ref, tri):.6f} ratio: {get_err_ratio(ref, tri):.6f}" + print(msg) + error_rate = get_err_ratio(ref, tri) + if warning or str(prefix).strip().lower() == "dh0" or (ci_env and error_rate < 0.01): + if error_rate > ratio: + import warnings + warnings.warn(msg) + else: + assert error_rate < ratio, msg
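Usage note (illustrative, not part of the patch): a minimal sketch of calling the `mean_pooling` entry point from fla/ops/utils/pooling.py above. The import path follows the diff header; the CUDA device, tensor sizes, and the expected output shapes (derived from the chunk-count arithmetic in `mean_pooling_fwd`) are assumptions of this sketch.

import torch
from fla.ops.utils.pooling import mean_pooling

B, T, H, D, chunk_size = 2, 256, 4, 64, 64
x = torch.randn(B, T, H, D, device='cuda', requires_grad=True)

# Fixed-length batch: one pooled vector per chunk of `chunk_size` timesteps.
o = mean_pooling(x, chunk_size)    # expected shape (B, T // chunk_size, H, D) = (2, 4, 4, 64)
o.sum().backward()                 # gradients flow back through MeanPoolingFunction

# Variable-length batch: flatten to batch size 1 and describe sequences via cu_seqlens.
x_var = torch.randn(1, 300, H, D, device='cuda')
cu_seqlens = torch.tensor([0, 100, 300], device='cuda', dtype=torch.long)
o_var = mean_pooling(x_var, chunk_size, cu_seqlens=cu_seqlens)
# 100 tokens -> 2 chunks, 200 tokens -> 4 chunks (ceil division), so o_var is expected to be (1, 6, H, D)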
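Note on fla/ops/utils/softmax.py (illustrative, not part of the patch): `softmax_bwd` computes the standard softmax vector-Jacobian product, ds = p * (dp - sum(p * dp, dim=-1, keepdim=True)), which matches the kernel's `b_ds = b_p * b_dp - b_p * b_pp`. The sketch below checks that reading against autograd; the shapes, device, and tolerances are arbitrary choices of this sketch.

import torch
from fla.ops.utils.softmax import softmax_fwd, softmax_bwd

x = torch.randn(8, 1000, device='cuda', requires_grad=True)
p = softmax_fwd(x)                       # row-wise softmax, fp32 output by default

dp = torch.randn_like(p)                 # some upstream gradient w.r.t. p
ds = softmax_bwd(p, dp)                  # p * (dp - <p, dp>) per row

# Reference via autograd.
p_ref = torch.softmax(x, dim=-1)
ds_ref, = torch.autograd.grad(p_ref, x, grad_outputs=dp)
torch.testing.assert_close(ds, ds_ref, rtol=1e-3, atol=1e-3)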
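Numerical check for fla/ops/utils/solve_tril.py (illustrative, not part of the patch): the block pointers in `solve_tril_16x16_kernel` read A as one strictly lower-triangular BT x BT block per chunk of BT rows, and the merge kernels assemble the blockwise inverse of (I + A). Under that reading, the sketch below compares against torch.linalg.inv chunk by chunk; the 0.1 scaling only keeps the inverses well conditioned, and the sizes, device, and tolerances are assumptions of this sketch.

import torch
from fla.ops.utils.solve_tril import solve_tril

B, T, H, BT = 2, 128, 4, 64
NT = T // BT

# One strictly lower-triangular (BT, BT) block per chunk, laid out as (B, T, H, BT).
A_blocks = 0.1 * torch.randn(B, NT, H, BT, BT, device='cuda').tril(-1)
A = A_blocks.permute(0, 1, 3, 2, 4).reshape(B, T, H, BT).contiguous()

Ai = solve_tril(A)                                  # (I + A)^{-1} in the same layout as A

# Reference: invert (I + block) for every chunk and head.
eye = torch.eye(BT, device=A.device)
ref = torch.linalg.inv(eye + A_blocks)
Ai_blocks = Ai.view(B, NT, BT, H, BT).permute(0, 1, 3, 2, 4)
torch.testing.assert_close(Ai_blocks, ref, rtol=1e-3, atol=1e-3)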