diff --git a/fla/ops/abc/__init__.py b/fla/ops/abc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdac8d900fc51485a55716443ee1f00424b522b9 --- /dev/null +++ b/fla/ops/abc/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_abc + +__all__ = [ + 'chunk_abc' +] diff --git a/fla/ops/abc/__pycache__/__init__.cpython-311.pyc b/fla/ops/abc/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad3ed038b741b5395eeb1a3db1e33ce20833a961 Binary files /dev/null and b/fla/ops/abc/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/abc/__pycache__/chunk.cpython-311.pyc b/fla/ops/abc/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a7a62b584a55b644e72c040c54fdc9da21c6a4f Binary files /dev/null and b/fla/ops/abc/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/abc/chunk.py b/fla/ops/abc/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..8538e04800cd71414782ff72668df1fbd97984b1 --- /dev/null +++ b/fla/ops/abc/chunk.py @@ -0,0 +1,1116 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import logcumsumexp_fwd_kernel, softmax_bwd, softmax_fwd +from fla.ops.utils.op import exp +from fla.utils import input_guard + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_h( + k, + v, + z, + h, + h0, + ht, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + if NORMK: + p_z0 = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_k * BK,), (BK,), (0,)) + else: + p_z0 = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_z0).to(tl.float32) + for i_t in range(NT): + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + if NORMK: + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[:, None] + b_k = exp(b_k - b_zc[:, None]).to(b_k.dtype) + else: + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zp - b_zc), b_zc + # [BK, BV] + b_h = b_h * b_r[None, :] + b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype) + # [BK, BV] + b_h += tl.dot(b_k, b_v, allow_tf32=False) + + if STORE_FINAL_STATE: + p_h = 
tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_intra_K( + v, + z, + o, + A, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_o = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(0, i_i): + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A, exp(b_v - b_zn[None, :]).to(b_v.dtype), allow_tf32=False) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_o *= exp(b_zn[None, :] - b_z) + + o_i = tl.arange(0, BC) + o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(A + o_A + j, mask=m_A, other=0) + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC, BV] + # avoid 0 * inf = inf + m_i = o_i[:, None] >= j + b_o += tl.where(m_i, b_A[:, None] * exp(b_v[None, :] - b_z), 0) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_K( + q, + k, + z, + h, + o, + A, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_o += tl.dot(b_q, b_h, allow_tf32=False) + # [BT, BT] + b_A += tl.dot(b_q, b_k, allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BV] + p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), 
(1,), (i_p * V + i_v * BV,), (BV,), (0,)) + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_o = b_o * exp(b_zp[None, :] - b_z) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_A = tl.where(m_s, b_A, 0.) + if i_v == 0: + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_intra_V( + q, + k, + z, + A, + scale, + T, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * exp(b_zn[None, :] - b_z) * scale).to(b_q.dtype) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = exp(b_k - b_zn[:, None]).to(b_k.dtype) + # [BC, BC] + b_A = tl.dot(b_q, b_k, allow_tf32=False) + tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_j * BC) * K + i_k * BK,), (BK,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_k * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BK,] + b_k = tl.load(p_k, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_A = tl.sum(b_q * exp(b_k[None, :] - b_z) * scale, 1) + b_A = tl.where(o_i >= j, b_A, 0.) 
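+ # within the diagonal sub-block only rows at or after position j are causal;
+ # the masked scores are written out as column j of the intra-chunk matrix A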
+ tl.store(A + o_A + j, b_A.to(b_q.dtype), mask=m_A) + + p_k = tl.advance(p_k, (K,)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_fwd_kernel_V( + q, + v, + z, + h, + o, + A, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + # [BT, BK] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_q = (b_q * exp(b_zp[None, :] - b_z)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_q, b_h, allow_tf32=False) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_o += tl.dot(b_A.to(b_v.dtype), b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_dh( + q, + z, + do, + dh, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + NORMK: tl.constexpr +): + i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + b_zp = tl.full([BK if NORMK else BV], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + i_p = tl.maximum(i_t * BT - 1, 0) + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + if NORMK: + p_z = tl.make_block_ptr(z + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zc - b_zp), b_zc + # [BK, BT] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_q = (b_q * exp(b_zc[:, None] - b_z)).to(b_q.dtype) + # [BK, BV] + b_dh = b_dh * b_r[:, None] + else: + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * 
V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + b_r, b_zp = exp(b_zc - b_zp), b_zc + # [BT, BV] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = (b_do * exp(b_zc[None, :] - b_z)).to(b_do.dtype) + # [BK, BV] + b_dh = b_dh * b_r[None, :] + # [BK, BV] + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_V( + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1)) + + # [BK,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = exp(b_k - b_zc[None, :]).to(b_k.dtype) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * V * K, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BV] + b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + if i_k == 0: + b_dv += tl.dot(b_A.to(b_do.dtype), b_do, allow_tf32=False) + b_do = (b_do * scale).to(b_do.dtype) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BT, BT] + b_dA += tl.dot(b_do, tl.trans(b_v), allow_tf32=False) + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + # [BT, BK] + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), (i_p * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + # [BT, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = exp(b_zp[None, :] - b_z) + # [BT, BK] + b_dq = b_dq * b_z + b_dk = b_dk * b_k + + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT,), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] 
>= o_i[None, :] + # [BT, BT] + b_dA = tl.where(m_s, b_dA, 0.).to(b_k.dtype) + if i_k == 0: + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_intra_V( + q, + k, + z, + dA, + dq, + dk, + T, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr +): + i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_zq = exp(b_zn[None, :] - b_z) + b_dq = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = exp(b_k - b_zn[None, :]).to(b_k.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dq += tl.dot(b_dA, b_kz, allow_tf32=False) + b_dq *= b_zq + + o_i = tl.arange(0, BC) + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_i * BC + m_dA = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + p_kj = tl.make_block_ptr(k + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i*BC+j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j, mask=m_dA, other=0) + # [BK,] + b_kj = tl.load(p_kj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_dq += tl.where(m_i, b_dA[:, None] * exp(b_kj[None, :] - b_z), 0.) 
+ p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*K, (T*K,), (1,), ((i_t * BT + i_i * BC + BC - 1) * K + i_k * BK,), (BK,), (0,)) + # [BK,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kz = exp(b_k - b_zn[None, :]) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_j * BC, i_i * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_qz = (b_q * exp(b_zn[None, :] - b_z)).to(b_q.dtype) + # [BC, BC] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BC, BK] + b_dk += tl.dot(tl.trans(b_dA), b_qz, allow_tf32=False) + b_dk *= b_kz + + o_dA = i_bh * T * BT + (i_t * BT + i_i * BC) * BT + i_i * BC + tl.arange(0, BC) + for j in range(0, BC): + p_qj = tl.make_block_ptr(q + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + p_zj = tl.make_block_ptr(z + i_bh * T*K, (T * K,), (1,), ((i_t * BT + i_i * BC + j) * K + i_k * BK,), (BK,), (0,)) + # [BC,] + b_dA = tl.load(dA + o_dA + j * BT, mask=(i_t * BT + i_i * BC + j < T), other=0) + # [BK,] + b_qj = tl.load(p_qj, boundary_check=(0,)).to(tl.float32) + b_zj = tl.load(p_zj, boundary_check=(0,)).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_dk += tl.where(m_i, b_dA[:, None] * b_qj[None, :] * exp(b_k - b_zj[None, :]), 0.) 
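+ # b_dk now holds both the contributions from later sub-chunks and from the
+ # diagonal sub-chunk, so it can be written back in one store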
+ p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_intra_K( + v, + z, + do, + dA, + scale, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i, i_j = i_c // (NC * NC), (i_c % (NC * NC)) // NC, (i_c % (NC * NC)) % NC + n_bh = tl.num_programs(2) + + if i_i > i_j: + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT + i_j * BC), (BV, BC), (0, 1)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC) * V + i_v * BV,), (BV,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_dA = tl.make_block_ptr(dA+(i_bh+i_v*n_bh)*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_zn[None, :] - b_z) * scale).to(b_do.dtype) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = exp(b_v - b_zn[:, None]).to(b_v.dtype) + # [BC, BC] + b_dA = tl.dot(b_do, b_v, allow_tf32=False) + tl.store(p_dA, b_dA.to(dA.dtype.element_ty), boundary_check=(0, 1)) + elif i_i == i_j: + p_v = tl.make_block_ptr(v + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_j * BC) * V + i_v * BV,), (BV,), (0,)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) * scale + + o_i = tl.arange(0, BC) + o_A = (i_bh + i_v * n_bh) * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC + m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T + for j in range(0, BC): + # [BV,] + b_v = tl.load(p_v, boundary_check=(0,)).to(tl.float32) + # [BC,] + b_dA = tl.sum(b_do * exp(b_v[None, :] - b_z), 1) + b_dA = tl.where(o_i >= j, b_dA, 0) + tl.store(dA + o_A + j, b_dA.to(b_do.dtype), mask=m_A) + + p_v = tl.advance(p_v, (V,)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_K( + q, + k, + v, + z, + h, + A, + do, + dh, + dq, + dk, + dv, + dA, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_p = tl.maximum(i_t * BT - 1, 0) + n_bh = tl.num_programs(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_A = tl.make_block_ptr(A + (i_k*n_bh+i_bh) * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.dot((b_q * scale).to(b_q.dtype), tl.trans(b_k), allow_tf32=False) + b_A = tl.where(m_s, b_A, 0.) 
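+ # cache the causally masked intra-chunk scores q @ k^T; after the partial results
+ # are summed over key blocks on the host side, they are reused to push gradients
+ # back into the slot representations s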
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_zp = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), (i_p * V + i_v * BV,), (BV,), (0,)) + p_zc = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,)) + p_h = tl.make_block_ptr(h + i_bh * NT*K*V + i_t * K*V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + i_bh * NT*K*V + i_t * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh) * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + # [BV,] + b_zp = tl.load(p_zp, boundary_check=(0,)) + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v = exp(b_v - b_zc[None, :]).to(b_v.dtype) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_z = exp(b_zp[None, :] - b_z) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * b_z * scale).to(b_do.dtype) + # [BK, BV] + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BT, BK] + b_dq += tl.dot(b_do, b_h, allow_tf32=False) + b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False) + # [BT, BV] + b_dv = b_v * tl.dot(b_k, b_dh, allow_tf32=False) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + p_dA = tl.make_block_ptr(dA + i_bh * T * BT, (T, BT, ), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BT] + b_dA = tl.load(p_dA, boundary_check=(0, 1)) + # [BT, BK] + b_dq += tl.dot(b_dA, b_k, allow_tf32=False) + b_dk += tl.dot(tl.trans(b_dA).to(b_k.dtype), b_q, allow_tf32=False) + + p_dq = tl.make_block_ptr(dq + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_intra_KV( + v, + z, + A, + do, + dv, + T, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BV: tl.constexpr, + NC: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*V, (T*V,), (1,), ((i_t * BT + i_i * BC + BC - 1) * V + i_v * BV,), (BV,), (0,)) + # [BV,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (BT, T), (1, BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_j * BC, i_v * BV), (BC, BV), (1, 0)) + # [BC, BV] + b_z = tl.load(p_z, boundary_check=(0, 
1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_do = (b_do * exp(b_zn[None, :] - b_z)).to(b_do.dtype) + # [BC, BC] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dv += tl.dot(b_A, b_do, allow_tf32=False) + b_dv *= exp(b_v - b_zn[None, :]) + + o_i = tl.arange(0, BC) + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + p_A = tl.make_block_ptr(A + i_bh * T * BT, (T * BT,), (1,), ((i_t * BT + i_i * BC + j) * BT + i_i * BC,), (BC,), (0,)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T * V,), (1,), ((i_t * BT + i_i * BC + j) * V + i_v * BV,), (BV,), (0,)) + # [BC,] + b_A = tl.load(p_A, boundary_check=(0,)) + # [BV,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_do = tl.load(p_do, boundary_check=(0,)) + # [BC, BV] + m_i = o_i[:, None] <= j + b_dv += tl.where(m_i, exp(b_v - b_z[None, :]) * b_A[:, None] * b_do[None, :], 0.) + p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT + i_i * BC, i_v * BV), (BC, BV), (1, 0)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_rcum_inter( + s, + z, + ss, + doo, + T, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + NT: tl.constexpr +): + i_m, i_bh = tl.program_id(0), tl.program_id(1) + + b_sp = tl.zeros([BS,], dtype=tl.float32) + b_zp = tl.full([BS,], float('inf'), dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_zc = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT) * S + i_m * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT, i_m * BS), (BT, BS), (1, 0)) + # [BS,] + b_zc = tl.load(p_zc, boundary_check=(0,)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + + b_doo = exp(b_s - b_zp[None, :]) * b_sp[None, :] + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + # [BS,] + b_sp = b_sp * exp(b_zc - b_zp) + tl.sum(b_ss * exp(b_zc[None, :] - b_z), 0) + b_zp = b_zc + + +@triton.jit(do_not_specialize=['T']) +def chunk_abc_bwd_kernel_rcum_intra( + s, + z, + ss, + doo, + T, + S: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + NC: tl.constexpr +): + i_s, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_t, i_i = i_c // NC, i_c % NC + + o_i = tl.arange(0, BC) + m_o = tl.full([BC, BC], 1., dtype=tl.float32) + + p_s = tl.make_block_ptr(s + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + p_zn = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + BC - 1) * S + i_s * BS,), (BS,), (0,)) + p_doo = tl.make_block_ptr(doo + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_i * BC, i_s * BS), (BC, BS), (1, 0)) + # [BC, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)) + # [BS,] + b_zn = tl.load(p_zn, boundary_check=(0,)) + + b_doo = tl.zeros([BC, BS], dtype=tl.float32) + for i_j in range(i_i + 1, NC): + p_z = tl.make_block_ptr(z + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 0)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T, S), (S, 1), (i_t * BT + i_j * BC, i_s * BS), (BC, BS), (1, 
0)) + # [BC, BS] + b_z = tl.load(p_z, boundary_check=(0, 1)) + b_ss = tl.load(p_ss, boundary_check=(0, 1)) + # [BC, BS] + b_doo += b_ss * exp(b_zn[None, :] - b_z) + b_doo = exp(b_s - b_zn[None, :]) * tl.dot(m_o.to(b_s.dtype), b_doo.to(b_s.dtype), allow_tf32=False) + + for j in range(0, BC): + p_z = tl.make_block_ptr(z + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + p_ss = tl.make_block_ptr(ss + i_bh * T*S, (T*S,), (1,), ((i_t * BT + i_i * BC + j) * S + i_s * BS,), (BS,), (0,)) + # [BS,] + b_z = tl.load(p_z, boundary_check=(0,)) + b_ss = tl.load(p_ss, boundary_check=(0,)) + # [BC, BS] + m_i = o_i[:, None] <= j + b_doo += tl.where(m_i, exp(b_s - b_z[None, :]) * b_ss[None, :], 0.) + b_doo += tl.load(p_doo, boundary_check=(0, 1)) + tl.store(p_doo, b_doo.to(p_doo.dtype.element_ty), boundary_check=(0, 1)) + + +class ChunkABCFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, q, k, v, s, initial_state, output_final_state): + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = 64, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NV, NM = triton.cdiv(V, BV), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def fwd_pre(s, B, H, T, S): + # keep cummulative normalizer in fp32 + z = torch.empty_like(s, dtype=torch.float) + grid = (B * H,) + logcumsumexp_fwd_kernel[grid]( + s, z, + T=T, S=S + ) + return z + + def fwd_inner(q, k, v, z, B, H, T, K, V, BT, BK, BV, NT, normk=False, h0=None, ht=None): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + h = q.new_empty(B, H, NT * K, V) + grid = (NV, NK, B * H) + chunk_abc_fwd_kernel_h[grid]( + k, v, z, h, h0, ht, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + USE_INITIAL_STATE=h0 is not None, + STORE_FINAL_STATE=ht is not None, + num_warps=num_warps, + num_stages=num_stages + ) + return h + + final_state = None + if output_final_state: + final_state = (q.new_empty(B, H, K, M, dtype=torch.float), + q.new_empty(B, H, M, V, dtype=torch.float)) + + z = fwd_pre(s, B, H, T, M) + scale = K ** -0.5 + hk = fwd_inner( + q=q, k=k, v=s, z=z, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + normk=False, + h0=initial_state[0] if initial_state is not None else None, + ht=final_state[0] if final_state is not None else None + ) + ok1 = torch.empty_like(s) + Ak = q.new_empty(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_fwd_kernel_K[grid]( + q, k, z, hk, ok1, Ak, + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + ok0 = torch.empty_like(s) + grid = (NM, NT * NC, B * H) + chunk_abc_fwd_kernel_intra_K[grid]( + s, z, ok0, Ak, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ok = ok0.add_(ok1) + + scale = 1. + # p is kept in fp32 for safe softmax backward + p = softmax_fwd(ok, dtype=torch.float) + qv = p.to(q.dtype) + + scale = 1. 
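+ # second pass (value read-out): the softmaxed slot scores qv act as queries against
+ # the slot->value state; normk=True lets fwd_inner fold the slot normalizer z into
+ # the recurrence for numerical stability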
+ hv = fwd_inner( + q=qv, k=s, v=v, z=z, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + normk=True, + h0=initial_state[1] if initial_state is not None else None, + ht=final_state[1] if final_state is not None else None + ) + Av = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_fwd_kernel_intra_V[grid]( + qv, s, z, Av, + scale=scale, + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + Av = Av.sum(0) + ov = torch.empty_like(v) + grid = (NV, NT, B * H) + chunk_abc_fwd_kernel_V[grid]( + qv, v, z, hv, ov, Av, + scale=scale, + T=T, + K=M, + V=V, + BT=BT, + BK=BM, + BV=BV, + NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v, s, z, ok, p, hk, hv, Av) + ctx.BT = BT + return ov, final_state + + @staticmethod + @input_guard + def backward(ctx, dov, dht=None): + q, k, v, s, z, ok, p, hk, hv, Av = ctx.saved_tensors + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + BT, BC = ctx.BT, 16 + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + BM = min(64, triton.next_power_of_2(M)) + NT, NC = triton.cdiv(T, BT), triton.cdiv(BT, BC) + NK, NM = triton.cdiv(K, BK), triton.cdiv(M, BM) + num_warps = 4 if BK == 64 else 2 + num_stages = 1 + + def bwd_inner(q, z, do, B, H, T, K, V, BT, BK, BV, NT, scale, normk=False): + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + dh = q.new_empty(B, H, NT * K, V) + grid = (NK, NV, B * H) + chunk_abc_bwd_kernel_dh[grid]( + q, z, do, dh, + scale=scale, + T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT, + NORMK=normk, + num_warps=num_warps, + num_stages=num_stages + ) + return dh + + def bwd_post(s, z, ss, B, H, T, S, BT, BC, BS, NT, NC, NS): + doo = torch.empty_like(s) + grid = (NS, B * H) + chunk_abc_bwd_kernel_rcum_inter[grid]( + s, z, ss, doo, + T=T, S=S, BT=BT, BS=BS, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + grid = (NS, NT * NC, B * H) + chunk_abc_bwd_kernel_rcum_intra[grid]( + s, z, ss, doo, + T=T, S=S, BT=BT, BC=BC, BS=BS, NC=NC, + num_warps=num_warps, + num_stages=num_stages + ) + return doo + + scale = 1. 
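+ # backward mirrors the forward: first backprop through the value pass (dhv, dp and
+ # part of ds), then through the slot softmax (dok), and finally through the key
+ # pass to obtain dq, dk and the remaining ds terms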
+ qv = p.to(q.dtype) + dhv = bwd_inner( + qv, z, dov, + B=B, H=H, T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + scale=scale, + normk=True + ) + dp1 = torch.empty_like(p) + dsv1 = torch.empty_like(s, dtype=torch.float) + dv = v.new_empty(NM, *v.shape) + dAv = q.new_zeros(B, H, T, BT) + grid = (NM, NT, B * H) + chunk_abc_bwd_kernel_V[grid]( + s, v, z, hv, Av, dov, dhv, dp1, dsv1, dv, dAv, + scale=scale, + T=T, K=M, V=V, BT=BT, BK=BM, BV=BV, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + dv = dv.sum(0) + dp0 = torch.empty_like(p) + dsv0 = s.new_zeros(s.shape, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_V[grid]( + qv, s, z, dAv, dp0, dsv0, + T=T, K=M, BT=BT, BC=BC, BK=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dp = dp1.add_(dp0) + dsv = dsv1.add_(dsv0) + + # softmax gradient, equivalent to: + # dok = p * (dp - (p * dp).sum(-1, True)) + dok = softmax_bwd(p, dp, dtype=ok.dtype) + + scale = K ** -0.5 + dhk = bwd_inner( + q, z, dok, + B=B, H=H, T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + scale=scale, + normk=False + ) + dAk = q.new_zeros(NM, B, H, T, BT) + grid = (NM, NT * NC * NC, B * H) + chunk_abc_bwd_kernel_intra_K[grid]( + s, z, dok, dAk, + scale=scale, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + dAk = dAk.sum(0) + + Ak = q.new_zeros(NK, B, H, T, BT) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dsk1 = s.new_empty(NK, *s.shape, dtype=torch.float) + grid = (NK, NT, B * H) + chunk_abc_bwd_kernel_K[grid]( + q, k, s, z, hk, Ak, dok, dhk, dq, dk, dsk1, dAk, + scale=scale, + T=T, K=K, V=M, BT=BT, BK=BK, BV=BM, NT=NT, + num_warps=num_warps, + num_stages=num_stages + ) + Ak = Ak.sum(0) + dsk1 = dsk1.sum(0) + dsk0 = torch.empty_like(s, dtype=torch.float) + grid = (NM, NT * NC, B * H) + chunk_abc_bwd_kernel_intra_KV[grid]( + s, z, Ak, dok, dsk0, + T=T, V=M, BT=BT, BC=BC, BV=BM, NC=NC, + num_warps=2, + num_stages=num_stages + ) + ds = dsv.add_(dsk1.add_(dsk0)) + ds -= bwd_post(s, z, ok * dok + p * dp, B, H, T, M, BT, BC, BM, NT, NC, NM) + ds = ds.to(s.dtype) + return dq, dk, dv, ds, None, None + + +@torch.compiler.disable +def chunk_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + initial_state: Optional[Tuple[torch.Tensor]] = None, + output_final_state: bool = False, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + s (torch.Tensor): + slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]` + initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]): + Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `False`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[B, H, K, M]` and `[B, H, M, V]` if `output_final_state=True` else `None`. 
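+ Examples:
+ A minimal usage sketch; the shapes, dtype and device below are illustrative
+ assumptions rather than requirements of the kernel:
+ >>> import torch
+ >>> from fla.ops.abc import chunk_abc
+ >>> B, H, T, K, V, M = 2, 4, 128, 64, 64, 32
+ >>> q = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
+ >>> k = torch.randn(B, H, T, K, device='cuda', dtype=torch.bfloat16)
+ >>> v = torch.randn(B, H, T, V, device='cuda', dtype=torch.bfloat16)
+ >>> s = torch.randn(B, H, T, M, device='cuda', dtype=torch.bfloat16)
+ >>> o, (hk, hv) = chunk_abc(q, k, v, s, output_final_state=True)
+ >>> o.shape, hk.shape, hv.shape
+ (torch.Size([2, 4, 128, 64]), torch.Size([2, 4, 64, 32]), torch.Size([2, 4, 32, 64]))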
+ """ + if not head_first: + q, k, v, s = map(lambda x: x.transpose(1, 2), (q, k, v, s)) + o, final_state = ChunkABCFunction.apply(q, k, v, s, initial_state, output_final_state) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla/ops/abc/naive.py b/fla/ops/abc/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f25c40db73bcf33d1599761be0008cc5be7c59 --- /dev/null +++ b/fla/ops/abc/naive.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +from einops import repeat + + +def naive_recurrent_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[int] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: Optional[bool] = False +) -> torch.Tensor: + dtype = q.dtype + + NG = q.shape[1]//k.shape[1] + # [batch_size, n_heads, seq_len, n_slots] + if g is None: + z = s.float().logcumsumexp(2) + g = torch.cat((z[:, :, :1], z[:, :, :-1]), 2) - z + s = torch.exp(s - z) + q, k, v, s, g = map(lambda x: x.float(), (q, k, v, s, g)) + k, v, s, g = map(lambda x: repeat(x, 'b h t d -> b (h g) t d', g=NG), (k, v, s, g)) + if initial_state is not None: + initial_state = tuple(map(lambda x: repeat(x, 'b h k v -> b (h g) k v', g=NG), initial_state)) + + B, H, T, K, V, M = *q.shape, v.shape[-1], s.shape[-1] + + hk = torch.zeros(B, H, K, M, dtype=torch.float, device=q.device) + ok = torch.zeros_like(s) + + if scale is None: + scale = q.shape[-1] ** -0.5 + + final_state = None + if initial_state is not None: + hk += initial_state[0] + + for i in range(T): + q_i = q[:, :, i] * scale + k_i = k[:, :, i] + v_i = s[:, :, i] + g_i = g[:, :, i].exp() + hk = hk * g_i[..., None, :] + k_i[..., None] * v_i[..., None, :] + ok[:, :, i] = (q_i[..., None] * hk).sum(-2) + + qv = ok.softmax(-1) + hv = torch.zeros(B, H, M, V, dtype=torch.float, device=q.device) + ov = torch.zeros_like(v) + if initial_state is not None: + hv += initial_state[1] + + for i in range(T): + q_i = qv[:, :, i] + k_i = s[:, :, i] + v_i = v[:, :, i] + g_i = g[:, :, i].exp() + hv = hv * g_i[..., :, None] + k_i[..., None] * v_i[..., None, :] + ov[:, :, i] = (q_i[..., None] * hv).sum(-2) + + if output_final_state: + final_state = (hk.view(B, -1, NG, K, M)[:, :, 0], hv.view(B, -1, NG, M, V)[:, :, 0]) + return ov.to(dtype), final_state + + +def naive_cumsum_abc( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + s: torch.Tensor +) -> torch.Tensor: + """ + A simple implementation of vanilla ABC that is more aligned with the descriptions in the paper. + This is just for demonstration purposes, with no numerical stabilities guaranteed. 
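+ Concretely, it computes (mirroring the code below, with m indexing the M slots and
+ subscripts denoting time steps):
+ z_t[m] = sum_{i<=t} exp(s_i[m])                  (per-slot normalizer)
+ K_t[m] = sum_{i<=t} exp(s_i[m]) * k_i / z_t[m]   (slot keys)
+ V_t[m] = sum_{i<=t} exp(s_i[m]) * v_i / z_t[m]   (slot values)
+ p_t    = softmax(scale * q_t @ K_t^T)            (read-out weights over the slots)
+ o_t    = p_t @ V_t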
+ """ + + dtype = q.dtype + q, k, v, s = map(lambda x: x.float(), (q, k, v, s)) + + scale = q.shape[-1] ** -0.5 + # [batch_size, n_heads, seq_len, n_slots] + s = (s - s.max(2, True)[0]).exp() + z = s.cumsum(2) + # [batch_size, n_heads, seq_len, n_slots, d_head] + K = (s.unsqueeze(-1) * k.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + V = (s.unsqueeze(-1) * v.unsqueeze(-2)).cumsum(2) / z.unsqueeze(-1) + # [batch_size, n_heads, seq_len, n_slots] + p = torch.einsum('...d,...md->...m', q * scale, K).softmax(-1) + # [batch_size, n_heads, seq_len, d_head] + o = torch.einsum('...m,...md->...d', p, V) + return o.to(dtype), None diff --git a/fla/ops/based/__init__.py b/fla/ops/based/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f20b31ba0ea4c7d345761fbd6ab5f6ced5136236 --- /dev/null +++ b/fla/ops/based/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +from .fused_chunk import fused_chunk_based +from .parallel import parallel_based + +__all__ = [ + 'fused_chunk_based', + 'parallel_based' +] diff --git a/fla/ops/based/__pycache__/__init__.cpython-311.pyc b/fla/ops/based/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcda1a53fc17ab35f3f9a47c82bfd2712f5b6f4b Binary files /dev/null and b/fla/ops/based/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/based/__pycache__/fused_chunk.cpython-311.pyc b/fla/ops/based/__pycache__/fused_chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ab4bf07441289e444428829e8a5ef2cb8972f28 Binary files /dev/null and b/fla/ops/based/__pycache__/fused_chunk.cpython-311.pyc differ diff --git a/fla/ops/based/__pycache__/parallel.cpython-311.pyc b/fla/ops/based/__pycache__/parallel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3f2178274d03c92c4e10a32321a4211b3a93af6 Binary files /dev/null and b/fla/ops/based/__pycache__/parallel.cpython-311.pyc differ diff --git a/fla/ops/based/fused_chunk.py b/fla/ops/based/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..ff5db4fb73022c677662a4f7d29d6b2ec3015194 --- /dev/null +++ b/fla/ops/based/fused_chunk.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit(do_not_specialize=['T']) +def fused_chunk_based_fwd_kernel( + q, + k, + v, + o, + z, + scale, # K ** -0.5 + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh + i_k*B*H) * T*V, (T, V), (V, 1), 
(0, i_v * BV), (BT, BV), (1, 0)) + + p_z = z + (i_bh + i_k * B * H) * T + tl.arange(0, BT) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_0o = 0 + + for i in range(0, tl.cdiv(T, BT)): + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK*BK, BT] + b_k_2o = b_k[:, None, :] * b_k[None, :, :] + b_k_2o = tl.reshape(b_k_2o, [BK * BK, BT]).to(b_k.dtype) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BK] + b_q = (tl.load(p_q, boundary_check=(0, 1)) * scale).to(b_k.dtype) + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_z = tl.zeros([BT], dtype=tl.float32) + + # interchunk + # zero-order + b_o += b_h_0o + b_z += k_0o + # first-order + b_o += tl.dot(b_q, b_h_1o.to(b_q.dtype), allow_tf32=False) + b_z += tl.sum(b_q * k_1o, axis=1) + # second-order + b_q_2o = b_q[:, :, None] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BT, BK * BK]).to(b_k.dtype) + b_o += tl.dot(b_q_2o, b_h_2o.to(b_q_2o.dtype), allow_tf32=False) * 0.5 + b_z += tl.sum(b_q_2o * k_2o, axis=1) * 0.5 + + # update running statistics + k_1o += tl.sum(b_k, axis=1)[None, :] + k_2o += tl.sum(b_k_2o, axis=1)[None, :] + k_0o += BT + + # intrachunk + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + # [TB, BV] + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=(i * BT + tl.arange(0, BT)) < T) + + # update hidden state + # [BK, BV] + b_h_2o = b_h_2o + tl.dot(b_k_2o.to(b_v.dtype), b_v, allow_tf32=False) + b_h_1o = b_h_1o + tl.dot(b_k, b_v, allow_tf32=False) + b_h_0o = b_h_0o + tl.sum(b_v, axis=0) + + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + p_z += BT + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_chunk_based_bwd_kernel( + # NV: number of split in the V dimension. 
NK: number of split in the K dimension + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, # K ** -0.5 + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + m_s = o_i[:, None] >= o_i[None, :] + + # [BV], zero-order taylor expansion + # b_h_0o = tl.zeros([BV], dtype=tl.float32) + # [BK, BV], first-order taylor expansion + b_h_1o = tl.zeros([BV, BK], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_h_2o = tl.zeros([BV, BK*BK], dtype=tl.float32) + + k_1o = tl.zeros([1, BK], dtype=tl.float32) + k_2o = tl.zeros([1, BK * BK], dtype=tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i * BT + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + + # load tensors + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT) + i * BT) < T) + # [BV, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # inter-chunk + b_dq += tl.dot(b_do, (b_h_1o).to(b_do.dtype), allow_tf32=False) + if i_v == 0: + b_dq += b_dz[:, None] * k_1o + b_dq_2o = tl.dot(b_do, (b_h_2o).to(b_do.dtype), allow_tf32=False) * 0.5 + if i_v == 0: + b_dq_2o += (b_dz[:, None] * k_2o) * 0.5 + b_dq_2o = tl.reshape(b_dq_2o, [BT, BK, BK]) + b_dq += tl.sum(b_dq_2o * b_q[:, :, None], axis=1) + b_dq += tl.sum(b_dq_2o * b_q[:, None, :], axis=2) + b_dq *= scale + + # intra-chunk + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_q.dtype), b_k, allow_tf32=False) + + # store + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # update hidden state + # [BT, BK*BK] + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + # [BV, BK*BK] + b_h_2o = b_h_2o + tl.dot(b_v, b_k_2o.to(b_v.dtype), allow_tf32=False) + # [BV, BK] + b_h_1o = b_h_1o + tl.dot(b_v, b_k, allow_tf32=False) + + if i_v == 0: + # update running statistics + k_1o += tl.sum(b_k, axis=0)[None, :] + k_2o += tl.sum(b_k_2o, axis=0)[None, :] + + tl.debug_barrier() + b_h_1o = None + b_h_2o = None + + # [BK, BV], first-order taylor expansion + b_dh_1o = tl.zeros([BK, BV], dtype=tl.float32) + # [BK, BK, BV] second-order taylor expansion + b_dh_2o = tl.zeros([BK*BK, BV], dtype=tl.float32) + b_dh_0o = tl.zeros([BV], dtype=tl.float32) + m_s = tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :] + + dq_1o = tl.zeros([1, BK], dtype=tl.float32) + dq_2o = tl.zeros([BK * BK, 1], dtype=tl.float32) + + for i in range(tl.cdiv(T, BT) * BT - BT, -BT, -BT): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, 
i), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (i, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (i, i_v*BV), (BT, BV), (1, 0)) + p_dz = dz + (i_bh) * T + tl.arange(0, BT) + i + + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(tl.arange(0, BT)+i) < T) + b_q = (b_q * scale).to(b_k.dtype) + + # intra chunk + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + b_ds = tl.where(m_s, b_ds, 0) + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + b_ds *= (1+b_s) + + b_dk += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_q), allow_tf32=False) + b_dv += tl.dot(b_s2.to(b_do.dtype), b_do, allow_tf32=False) + + # inter chunk + b_k_2o = b_k[:, :, None] * b_k[:, None, :] + b_k_2o = tl.reshape(b_k_2o, [BT, BK * BK]).to(b_k.dtype) + + b_dv += tl.dot(b_k, b_dh_1o.to(b_k.dtype), allow_tf32=False) + b_dv += tl.dot(b_k_2o, b_dh_2o.to(b_k.dtype), allow_tf32=False) + b_dv += b_dh_0o + + b_dk += tl.dot(b_v, tl.trans(b_dh_1o).to(b_k.dtype), allow_tf32=False) + + if i_v == 0: + b_dk += dq_1o + + b_dk_2o = tl.dot(b_dh_2o.to(b_k.dtype), tl.trans(b_v), allow_tf32=False) + if i_v == 0: + b_dk_2o += dq_2o + b_dk_2o = tl.reshape(b_dk_2o, [BK, BK, BT]) + b_k_fp32 = tl.trans(b_k.to(tl.float32)) + b_dk2 = tl.sum(b_dk_2o * b_k_fp32[:, None, :], axis=0) + b_dk2 += tl.sum(b_dk_2o * b_k_fp32[None, :, :], axis=1) + b_dk += tl.trans(b_dk2) + + # hidden state update + b_dh_0o += tl.sum(b_do, axis=0) + b_dh_1o = b_dh_1o + tl.dot(b_q, b_do, allow_tf32=False) + b_q_2o = b_q[None, :, :] * b_q[:, None, :] + b_q_2o = tl.reshape(b_q_2o, [BK * BK, BT]).to(b_k.dtype) + b_dh_2o = b_dh_2o + tl.dot(b_q_2o, b_do, allow_tf32=False) * 0.5 + + if i_v == 0: + dq_1o += (tl.sum(b_dz[None, :] * b_q, axis=1))[None, :] + dq_2o += (tl.sum(b_dz[None, :] * b_q_2o, axis=1) * 0.5)[:, None] + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkBasedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale=1): + B, H, T, K, V = *k.shape, v.shape[-1] + + scale = scale + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + num_warps = 4 + + # the norm of o might explode, so we need to use float32 here + o = q.new_empty(NK, B, H, T, V, dtype=torch.float32) + z = q.new_empty(NK, B, H, T, dtype=torch.float32) + + grid = (NV, NK, B * H) + fused_chunk_based_fwd_kernel[grid]( + q, k, v, o, z, + scale, + T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + ) + o = o.sum(0) + z = z.sum(0) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.to(q.dtype), z.to(z.dtype) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, 
do, dz): + q, k, v = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = 16 + BK, BV = min(K, 16), min(V, 32) + BK, BV = max(BK, 16), max(BV, 16) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_stages = 1 + num_warps = 4 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + scale, + T=T, B=B, H=H, K=K, V=V, BT=BT, BK=BK, BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None + + +def fused_chunk_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True, + head_first: bool = True +): + assert q.shape[-1] <= 16, 'only support feature dimension up to 16.' + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, z = FusedChunkBasedFunction.apply(q, k, v, scale) + if use_norm: + o = o / (z[..., None] + 1e-6) + if not head_first: + o = o.transpose(1, 2) + return o.to(q.dtype) diff --git a/fla/ops/based/naive.py b/fla/ops/based/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..4de614137ed28567ebb1df39c0892f498b91fb5a --- /dev/null +++ b/fla/ops/based/naive.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +import torch +from einops import rearrange + + +def naive_parallel_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True +): + if scale is None: + scale = q.shape[-1] ** -0.5 + q = q * scale + attn = q @ k.transpose(-2, -1) + attn = 1 + attn + 1/2 * (attn ** 2) + attn.masked_fill_(~torch.tril(torch.ones( + q.shape[-2], q.shape[-2], dtype=torch.bool, device=q.device)), 0) + o = attn @ v + if use_norm: + z = attn.sum(-1) + return o / (z[..., None] + 1e-6) + else: + return o + + +def naive_chunk_based(q, k, v, chunk_size=256): + q = q * (q.shape[-1] ** -0.5) + # compute normalizer. 
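+ # the Based feature map is the 2nd-order Taylor expansion of exp(q . k):
+ #   phi(q) . phi(k) = 1 + q . k + (q . k)^2 / 2
+ # so the normalizer is z_t = sum_{i<=t} (1 + q_t . k_i + (q_t . k_i)^2 / 2),
+ # accumulated below via the zeroth-, first- and second-order cumulative sums of k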
+ k_cumsum = torch.cumsum(k, dim=-2) + kk_cumsum = torch.cumsum(k.unsqueeze(-1) * k.unsqueeze(-2), dim=-3) + # first order + z = (q * k_cumsum).sum(-1) + # second order + z += (q.unsqueeze(-1) * q.unsqueeze(-2) * kk_cumsum).sum((-1, -2)) * 0.5 + # zeroth order + z += (torch.arange(0, q.shape[-2]).to(z.device) * 1.0 + 1.0)[None, None, :] + + # compute o + # constant term + _o = v.cumsum(-2) + + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) + + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + + intra_chunk_attn = q @ k.transpose(-2, -1) + intra_chunk_attn = intra_chunk_attn + 1/2 * (intra_chunk_attn ** 2) + intra_chunk_attn.masked_fill_(~torch.tril(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device)), 0) + o = intra_chunk_attn @ v + + # quadratic term + kv = torch.einsum('b h n c x, b h n c y, b h n c z -> b h n x y z', k, k, v) + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + + o += 0.5 * torch.einsum('b h n x y z, b h n c x, b h n c y -> b h n c z', kv, q, q) + + # linear term + kv = torch.einsum('b h n c x, b h n c y -> b h n x y', k, v) + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + o += torch.einsum('b h n x y, b h n c x -> b h n c y', kv, q) + + o = rearrange(o, 'b h n c d -> b h (n c) d') + o = o + _o + return o / (z[..., None] + 1e-6) diff --git a/fla/ops/based/parallel.py b/fla/ops/based/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..d4621ea5838bc410a33b1b0f0af40b3c322f02b5 --- /dev/null +++ b/fla/ops/based/parallel.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + +# Based: An Educational and Effective Sequence Mixer +# https://hazyresearch.stanford.edu/blog/2023-12-11-zoology2-based + + +@triton.jit(do_not_specialize=['T']) +def parallel_based_fwd_kernel( + q, + k, + v, + o, + z, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + # i_c: chunk index.
used for sequence parallelism + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % (NV) + + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BTS, BV), (1, 0)) + + # [BQ, BD] block Q, in the shared memory throughout the whole kernel + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + b_o = tl.zeros([BTL, BV], dtype=tl.float32) + b_z = tl.zeros([BTL], dtype=tl.float32) + + # Q block and K block have no overlap + # no need for mask, thereby saving flops + for _ in range(0, i_c * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_s = tl.dot(b_q, (b_k), allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_z += tl.sum(b_s, axis=1) + + # [BQ, BD] + b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False) + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + + # # rescale interchunk output + tl.debug_barrier() + o_q = tl.arange(0, BTL) + # # sync threads, easy for compiler to optimize + # tl.debug_barrier() + + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BK, BTS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BTS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_z += tl.sum(b_s, axis=1) + # [BTL, BV] + b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + + p_k = tl.advance(p_k, (0, BTS)) + p_v = tl.advance(p_v, (BTS, 0)) + o_k += BTS + + p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_z, b_z.to(p_z.dtype.element_ty), mask=((i_c * BTL + tl.arange(0, BTL)) < T)) + + +@triton.jit +def _parallel_based_bwd_dq( + i_bh, + i_c, + i_k, + i_v, + q, + k, + v, + do, + dz, + dq, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, +): + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + p_q = tl.make_block_ptr(q + (i_bh) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) + b_dq = tl.zeros([BTL, BK], dtype=tl.float32) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, 0), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL) + b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T) + + for _ in range(0, i_c * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, 
boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + # [BQ, BD] + b_dq += tl.dot((b_ds * (1 + b_s)).to(b_v.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + + b_dq *= scale + o_q = tl.arange(0, BTL) + o_k = tl.arange(0, BTS) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1)) + # Q block and K block have overlap. masks required + for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS): + # [BTS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BTS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BTL, BTS] + m_s = o_q[:, None] >= o_k[None, :] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[:, None] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BTL, BK] + b_dq += tl.dot((b_ds + b_ds * b_s).to(b_k.dtype), b_k, allow_tf32=False) + p_k = tl.advance(p_k, (BTS, 0)) + p_v = tl.advance(p_v, (0, BTS)) + o_k += BTS + p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit +def _parallel_based_bwd_dkv( + i_bh, + i_c, + i_k, + i_v, + q, + k, + v, + do, + dz, + dk, + dv, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, +): + # compute dk dv + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_c * BTL, i_v * BV), (BTL, BV), (1, 0)) + b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(p_v, boundary_check=(0, 1)) + b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros([BTL, BV], dtype=tl.float32) + + for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS] + b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS] + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * scale # [BTL, BTS] + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale + if i_v == 0: + b_ds += b_dz[None, :] * scale + else: + b_ds = b_ds + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + + tl.debug_barrier() + o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL) + for i in range(i_c*BTL, (i_c+1)*BTL, BTS): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i), (BK, BTS), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (V, T), (1, V), (i_v * BV, i), (BV, BTS), (0, 1)) + p_dz = dz + i_bh * T + i + tl.arange(0, BTS) + b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ] + b_do = tl.load(p_do, 
boundary_check=(0, 1)).to(b_q.dtype) + b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T) + # [BK, BQ] + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale + b_s2 = 1 + b_s + 0.5 * b_s * b_s + b_s = tl.where(m_s, b_s, 0) + b_s2 = tl.where(m_s, b_s2, 0) + + b_ds = tl.dot(b_v, b_do, allow_tf32=False) + if i_v == 0: + b_ds += b_dz[None, :] + else: + b_ds = b_ds + b_ds = tl.where(m_s, b_ds, 0) * scale + # [BK, BD] + b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False) + b_dk += tl.dot((b_ds + b_ds * b_s).to(b_q.dtype), tl.trans(b_q), allow_tf32=False) + o_q += BTS + + p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * T*K, (T, K), (K, 1), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * T*V, (T, V), (V, 1), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + return + + +@triton.jit(do_not_specialize=['T']) +def parallel_based_bwd_kernel( + q, + k, + v, + do, + dz, + dq, + dk, + dv, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BTL: tl.constexpr, + BTS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, +): + i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + NV = tl.cdiv(V, BV) + i_k = i_kv // (NV) + i_v = i_kv % NV + _parallel_based_bwd_dq( + i_bh, i_c, i_k, i_v, + q, k, v, do, dz, dq, + scale, T, B, H, BTL, BTS, BK, BV, K, V + ) + tl.debug_barrier() + _parallel_based_bwd_dkv( + i_bh, i_c, i_k, i_v, + q, k, v, do, dz, dk, dv, + scale, T, B, H, BTL, BTS, BK, BV, K, V + ) + + +class ParallelBasedFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale): + BTL, BTS = 128, 32 + assert BTL % BTS == 0 + # assert q.shape[-1] % 16 == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not." 
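+ # Note: with BK = min(128, next_power_of_2(K)) and the K <= 128 limit enforced by
+ # the public `parallel_based` wrapper below, a single key block always covers the
+ # whole head dimension (e.g. K = 100 -> BK = 128 -> NK = 1), so this assertion
+ # holds for every input the wrapper accepts.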
+ + o = torch.empty(NK, B, H, T, V, device=q.device) + z = torch.empty(NK, B, H, T, device=q.device) + parallel_based_fwd_kernel[grid]( + q, k, v, o, z, + scale, + B=B, + H=H, + T=T, + K=K, + V=V, + BTL=BTL, + BTS=BTS, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + ctx.save_for_backward(q, k, v) + ctx.scale = scale + return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dz): + q, k, v = ctx.saved_tensors + scale = ctx.scale + BTL, BTS = 64, 32 + assert BTL % BTS == 0 + BK = min(128, triton.next_power_of_2(k.shape[-1])) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + BK, BV = max(BK, 16), max(BV, 16) + B, H, T, K, V = *k.shape, v.shape[-1] + num_stages = 2 + num_warps = 4 + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + grid = (NK * NV, triton.cdiv(T, BTL), B * H) + + assert NK == 1, "will encounter some synchronization issue if not" + + dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device) + dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device) + + parallel_based_bwd_kernel[grid]( + q, k, v, do, dz, dq, dk, dv, + scale, + B=B, + H=H, + T=T, + K=K, + V=V, + BTL=BTL, + BTS=BTS, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages + ) + + return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None + + +triton_parallel_based = ParallelBasedFunction.apply + + +def parallel_based( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + use_norm: bool = True, + head_first: bool = True +): + assert q.shape[-1] <= 128, "only support feature dim up to 128" + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, z = triton_parallel_based(q, k, v, scale) + if use_norm: + o = o / (z[..., None] + 1e-6) + if not head_first: + o = o.transpose(1, 2) + return o.to(q.dtype) diff --git a/fla/ops/common/__init__.py b/fla/ops/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/fla/ops/common/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/fla/ops/common/__pycache__/__init__.cpython-311.pyc b/fla/ops/common/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc3da04704eba3d133ea826339c239b159aa1d14 Binary files /dev/null and b/fla/ops/common/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/common/__pycache__/chunk_delta_h.cpython-311.pyc b/fla/ops/common/__pycache__/chunk_delta_h.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..290754210d95b78e299564de63d6abe61982e3ea Binary files /dev/null and b/fla/ops/common/__pycache__/chunk_delta_h.cpython-311.pyc differ diff --git a/fla/ops/common/__pycache__/chunk_h.cpython-311.pyc b/fla/ops/common/__pycache__/chunk_h.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db8724069812fed257d7799f6650113387eb0756 Binary files /dev/null and b/fla/ops/common/__pycache__/chunk_h.cpython-311.pyc differ diff --git a/fla/ops/common/__pycache__/chunk_o.cpython-311.pyc b/fla/ops/common/__pycache__/chunk_o.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2098d178e16e0990426090dbda24ba254ae0f76 Binary files /dev/null and 
b/fla/ops/common/__pycache__/chunk_o.cpython-311.pyc differ diff --git a/fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-311.pyc b/fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4de3555ecb54fc17be989ceb186d7fb4d7b712af Binary files /dev/null and b/fla/ops/common/__pycache__/chunk_scaled_dot_kkt.cpython-311.pyc differ diff --git a/fla/ops/common/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/common/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa58f8ab51d8a2a75e1f9e3d66b258952f0116a1 Binary files /dev/null and b/fla/ops/common/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/common/__pycache__/utils.cpython-311.pyc b/fla/ops/common/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bcd06a13201bb0de33db2609bc615bc67efd20c Binary files /dev/null and b/fla/ops/common/__pycache__/utils.cpython-311.pyc differ diff --git a/fla/ops/common/chunk_delta_h.py b/fla/ops/common/chunk_delta_h.py new file mode 100644 index 0000000000000000000000000000000000000000..f5ed788cfa86c42bb9e04b90ae9c659321494bba --- /dev/null +++ b/fla/ops/common/chunk_delta_h.py @@ -0,0 +1,399 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_offsets +from fla.ops.utils.op import exp +from fla.utils import check_shared_mem, is_nvidia_hopper, use_cuda_graph + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_fwd_kernel_h( + k, + v, + d, + v_new, + g, + h, + h0, + ht, + offsets, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NT: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 
1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + b_hc = tl.zeros([BK, BV], dtype=tl.float32) + if USE_G: + last_idx = min((i_t + 1) * BT, T) - 1 + if HEAD_FIRST: + b_g_last = tl.load(g + i_nh * T + last_idx) + else: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + else: + b_g_last = None + last_idx = None + # since we need to make all DK in the SRAM. we face serve SRAM memory burden. By subchunking we allievate such burden + for i_c in range(tl.cdiv(min(BT, T - i_t * BT), BC)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_d = tl.make_block_ptr(d + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+i_nh*T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None + else: + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT+i_c*BC, i_v * BV), (BC, BV), (1, 0)) + p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT+i_c*BC, ), (BC,), (0,)) if USE_G else None + b_g = tl.load(p_g, boundary_check=(0, )) if USE_G else None + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_k = (b_k * exp(b_g_last - b_g)[None, :]).to(b_k.dtype) if USE_G else b_k + # [BC, BK] + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_d = (b_d * exp(b_g)[:, None]).to(b_d.dtype) if USE_G else b_d + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v2 = b_v - tl.dot(b_d, b_h.to(b_d.dtype)) + # [BK, BV] + tl.store(p_v_new, b_v2.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + b_hc += tl.dot(b_k, b_v2.to(b_k.dtype), allow_tf32=False) + b_h *= exp(b_g_last) if USE_G else 1 + b_h += b_hc + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_INITIAL_STATE': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV', 'USE_G'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_gated_delta_rule_bwd_kernel_dhu( + q, + k, + d, + g, + dht, + dh0, + do, + dh, + dv, + dv2, + offsets, + chunk_offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + 
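+ # Reverse sweep over chunks: each iteration stores the current state gradient dH,
+ # folds it into the value gradients (dv2 = dv + k @ dH), and then updates the
+ # recursion dH <- exp(g_last) * dH + q^T @ do - d^T @ dv2 (the gating factors
+ # apply only when USE_G), finally writing dh0 if an initial state was given.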
i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)) + + for i_t in range(NT - 1, -1, -1): + if HEAD_FIRST: + p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + b_dh_tmp = tl.zeros([BK, BV], dtype=tl.float32) + if USE_G: + last_idx = min((i_t + 1) * BT, T) - 1 + if HEAD_FIRST: + bg_last = tl.load(g + i_nh * T + last_idx) + else: + bg_last = tl.load(g + (bos + last_idx) * H + i_h) + else: + bg_last = None + last_idx = None + for i_c in range(tl.cdiv(BT, BC) - 1, -1, -1): + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_k = tl.make_block_ptr(k + i_nh * T*K, (T, K), (K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_d = tl.make_block_ptr(d + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = tl.make_block_ptr(dv + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_g = tl.make_block_ptr(g + i_nh * T, (T,), (1,), (i_t * BT + i_c * BC,), (BC,), (0,)) if USE_G else None + p_dv2 = tl.make_block_ptr(dv2 + i_nh * T*V, (T, V), (V, 1), (i_t * BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + else: + p_q = tl.make_block_ptr(q+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_k = tl.make_block_ptr(k+(bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_c * BC, i_k * BK), (BC, BK), (1, 0)) + p_d = tl.make_block_ptr(d+(bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_c * BC), (BK, BC), (0, 1)) + p_dv = tl.make_block_ptr(dv+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_do = tl.make_block_ptr(do+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_g = tl.make_block_ptr(g+bos*H+i_h, (T,), (H,), (i_t*BT + i_c * BC,), (BC,), (0,)) if USE_G else None + p_dv2 = tl.make_block_ptr(dv2+(bos*H+i_h)*V, (T, V), (H*V, 1), (i_t*BT + i_c * BC, i_v * BV), (BC, BV), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0,)) if USE_G else None + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale * exp(b_g)[None, :]).to(b_q.dtype) if USE_G else (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_d = tl.load(p_d, boundary_check=(0, 1)) + b_k = (b_k * exp(bg_last - b_g)[:, None]).to(b_k.dtype) if USE_G else b_k + b_d = (b_d * exp(b_g)[None, :]).to(b_d.dtype) if USE_G else b_d + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dv2 = b_dv + tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + tl.store(p_dv2, b_dv2.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BV] + b_dh_tmp += tl.dot(b_q, 
b_do.to(b_q.dtype), allow_tf32=False) + b_dh_tmp -= tl.dot(b_d, b_dv2.to(b_q.dtype), allow_tf32=False) + b_dh *= exp(bg_last) if USE_G else 1 + b_dh += b_dh_tmp + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, u.shape[-1] + else: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension larger than 256." + # H100 can have larger block size + if check_shared_mem('hopper', k.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + # A100 + elif check_shared_mem('ampere', k.device.index): + BV = 32 + BC = 64 + else: + BV = 32 + BC = 32 if K <= 128 else 16 + BC = min(BT, BC) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + if head_first: + h = k.new_empty(B, H, NT, K, V) + else: + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) + grid = (NK, NV, N * H) + + chunk_gated_delta_rule_fwd_kernel_h[grid]( + k=k, + v=u, + d=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + NT=NT, + HEAD_FIRST=head_first + ) + return h, v_new, final_state + + +def chunk_gated_delta_rule_bwd_dhu( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + g: torch.Tensor, + h0: torch.Tensor, + dht: Optional[torch.Tensor], + do: torch.Tensor, + dv: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *q.shape, do.shape[-1] + else: + B, T, H, K, V = *q.shape, do.shape[-1] + BT = min(chunk_size, max(triton.next_power_of_2(T), 16)) + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(offsets) - 1, len(indices), prepare_chunk_offsets(offsets, BT) + + BK = triton.next_power_of_2(K) + assert BK <= 256, "current kernel does not support head dimension being larger than 256." 
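+ # The whole key dimension fits in a single block (BK = next_power_of_2(K)), so only
+ # the value dimension is tiled; the BV and sub-chunk BC sizes below are chosen per
+ # architecture according to the available shared memory, similar to the forward pass.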
+ + # H100 + if check_shared_mem('hopper', q.device.index): + BV = 64 + BC = 64 if K <= 128 else 32 + # A100 + elif check_shared_mem('ampere', q.device.index): + BV = 32 + BC = 64 if K <= 128 else 32 + else: + BV = 32 if K <= 128 else 16 + BC = 16 + + BC = min(BT, BC) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization' + + if head_first: + dh = q.new_empty(B, H, NT, K, V) + else: + dh = q.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.empty_like(dv) + + grid = (NK, NV, N * H) + chunk_gated_delta_rule_bwd_kernel_dhu[grid]( + q=q, + k=k, + d=w, + g=g, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + offsets=offsets, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BC=BC, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dh, dh0, dv2 diff --git a/fla/ops/common/chunk_h.py b/fla/ops/common/chunk_h.py new file mode 100644 index 0000000000000000000000000000000000000000..0aa5a7a93b9741968fa03ab630eb8aba062ccc5f --- /dev/null +++ b/fla/ops/common/chunk_h.py @@ -0,0 +1,422 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_offsets +from fla.ops.utils.op import exp +from fla.utils import check_shared_mem + +BKV_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h( + k, + v, + h, + g, + gk, + gv, + h0, + ht, + offsets, + split_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = i_n * NS + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT): + i_s = i_t // (BS // BT) + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + o_h = (i_nh * NS + i_s).to(tl.int64) * K*V + p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), 
(i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V + p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t % (BS // BT) == 0: + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + last_idx = min((i_t + 1) * BT, T) - 1 + + # scalar decay + if USE_G: + if HEAD_FIRST: + b_g_last = tl.load(g + i_nh * T + last_idx) + p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT) + p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) + else: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_h *= exp(b_g_last) + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + if HEAD_FIRST: + p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_h *= exp(b_gk_last)[:, None] + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype) + + # vector decay, h = h @ Diag(gv) + if USE_GV: + if HEAD_FIRST: + p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) 
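+ # As in the gk branch above: decay the carried state by the chunk's last value gate
+ # and pre-scale the values by the remaining decay exp(gv_last - gv) so that the
+ # chunk's key-value contribution enters the state with consistent scaling.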
+ b_h *= exp(b_gv_last)[None, :] + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype) + + b_h += tl.dot(b_k, b_v) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh( + q, + g, + gk, + gv, + do, + dh, + dht, + dh0, + offsets, + split_offsets, + scale, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_nh // NG + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // NG + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + NS = tl.cdiv(T, BS) + boh = i_n * NS + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) + + for i_t in range(NT - 1, -1, -1): + i_s = i_t // (BS // BT) + if HEAD_FIRST: + o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V + p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V + p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t % (BS // BT) == 0: + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + last_idx = min(i_t * BT + BT, T) - 1 + # [BK, BT] + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + if USE_G: + if HEAD_FIRST: + p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT) + p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) + b_g_last = tl.load(g + i_bg * T + last_idx) + else: + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + 
b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) + + b_dh *= exp(b_g_last) + + if USE_GK: + if HEAD_FIRST: + p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_q = (b_q * exp(b_gk)).to(b_q.dtype) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_dh *= exp(b_gk_last)[:, None] + + if USE_GV: + if HEAD_FIRST: + p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_do = (b_do * exp(b_gv)).to(b_do.dtype) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_dh *= exp(b_gv_last)[None, :] + + b_dh += tl.dot(b_q, b_do) + + if STORE_INITIAL_STATE_GRADIENT: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + h0: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.Tensor] = None, + head_first: bool = True, + chunk_size: int = 64, + split_size: Optional[int] = None, + states_in_fp32: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T))) + assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}" + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + split_offsets, N, NS = None, B, triton.cdiv(T, BS) + else: + split_offsets = prepare_chunk_offsets(offsets, BS) + N, NS = len(offsets) - 1, split_offsets[-1] + + if head_first: + h = k.new_empty(B, H, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + else: + h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) + chunk_fwd_kernel_h[grid]( + k=k, + v=v, + h=h, + g=g, + gk=gk, + gv=gv, + h0=h0, + ht=ht, + offsets=offsets, + split_offsets=split_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BS=BS, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + return h, ht + + +def chunk_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + 
do: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + scale: float, + offsets: Optional[torch.Tensor] = None, + head_first: bool = True, + chunk_size: int = 64, + split_size: Optional[int] = None, + states_in_fp32: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + HQ = q.shape[1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T))) + assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}" + # N: the actual number of sequences in the batch with either equal or variable lengths + # NG: number of groups in GQA + if offsets is None: + split_offsets, N, NS = None, B, triton.cdiv(T, BS) + else: + split_offsets = prepare_chunk_offsets(offsets, BS) + N, NS = len(offsets) - 1, split_offsets[-1] + NG = HQ // H + + if head_first: + dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + else: + dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None + + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) + chunk_bwd_kernel_dh[grid]( + q=q, + g=g, + gk=gk, + gv=gv, + do=do, + dh=dh, + dht=dht, + dh0=dh0, + offsets=offsets, + split_offsets=split_offsets, + scale=scale, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + BS=BS, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + return dh, dh0 diff --git a/fla/ops/common/chunk_h_parallel.py b/fla/ops/common/chunk_h_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..51083eda8efbe012432ebf4a08fb34954a0dfd89 --- /dev/null +++ b/fla/ops/common/chunk_h_parallel.py @@ -0,0 +1,650 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +""" +Fully parallelized state passing. 
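+
+Chunk-level states are produced in two passes: a fully parallel kernel first writes
+each chunk's local key/value contribution (plus the initial state for chunk 0), and
+a per-sequence reduction kernel then sweeps the chunks in order, applying the decays
+and accumulating the running prefix states in place. The backward state gradients
+follow the same parallel-then-reduction scheme, traversed in reverse chunk order.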
+""" + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for BV in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h_parallel( + k, + v, + h, + g, + gk, + gv, + h0, + ht, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + NV = tl.cdiv(V, BV) + # i_b: batch index + # i_h: head index + # i_n: sequence index + # i_t: chunk index within current sequence + # i_tg: (global) chunk index across all sequences + i_k, i_v = i_kv // NV, i_kv % NV + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + bos, eos = i_b * T, i_b * T + T + NT = tl.cdiv(T, BT) + i_n, i_tg = i_b, i_b * NT + i_t + i_nh = i_n * H + i_h + + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t == 0: + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + else: + b_h = tl.zeros([BK, BV], dtype=tl.float32) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + last_idx = min(i_t * BT + BT, T) - 1 + # scalar decay + if USE_G: + if HEAD_FIRST: + b_g_last = tl.load(g + i_bh * T + last_idx) + p_g = g + i_bh * T + i_t * BT + tl.arange(0, BT) + p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) + else: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) 
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + if HEAD_FIRST: + p_gk = tl.make_block_ptr(gk + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + i_bh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype) + + # vector decay, h = h @ Diag(gv) + if USE_GV: + if HEAD_FIRST: + p_gv = tl.make_block_ptr(gv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + i_bh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype) + + b_h = tl.dot(b_k, b_v) + if i_t < NT - 1: + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_bh * NT + i_t + 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_h = tl.make_block_ptr(h + ((i_tg + 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + elif STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for BV in [32, 64, 128] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h_reduction( + h, + g, + gk, + gv, + kvt, + ht, + offsets, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT): + if HEAD_FIRST: + p_h = tl.make_block_ptr(h + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_h = tl.make_block_ptr(h + ((boh + i_t) * H + i_h) * K*V, (K, V), (V, 
1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + if i_t > 0: + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min(i_t * BT + BT, T) - 1 + # scalar decay + if USE_G: + if HEAD_FIRST: + b_g_last = tl.load(g + i_nh * T + last_idx) + else: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + b_h *= exp(b_g_last) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + if HEAD_FIRST: + p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_h *= exp(b_gk_last)[:, None] + + # vector decay, h = h @ Diag(gv) + if USE_GV: + if HEAD_FIRST: + p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_h *= exp(b_gv_last)[None, :] + + if STORE_FINAL_STATE: + p_kvt = tl.make_block_ptr(kvt + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_kvt, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for BV in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh_parallel( + q, + g, + gk, + gv, + do, + dh, + dht, + dh0, + offsets, + indices, + scale, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + NV = tl.cdiv(V, BV) + i_k, i_v = i_kv // NV, i_kv % NV + i_b, i_hq, i_bg = i_bh // HQ, i_bh % HQ, i_bh // NG + i_h = i_hq // NG + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + bos, eos = i_b * T, i_b * T + T + NT = tl.cdiv(T, BT) + i_n, i_tg = i_b, i_b * NT + i_t + i_nh = i_n * HQ + i_hq + + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + 
else: + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + if i_t == NT - 1: + if USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh = tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) + else: + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + if USE_G: + if HEAD_FIRST: + p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT) + p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) + else: + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) + + if USE_GK: + if HEAD_FIRST: + p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + else: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_q = (b_q * exp(b_gk)).to(b_q.dtype) + + if USE_GV: + if HEAD_FIRST: + p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_do = (b_do * exp(b_gv)).to(b_do.dtype) + + b_dh = tl.dot(b_q, b_do) + if i_t > 0: + if HEAD_FIRST: + p_dh = tl.make_block_ptr(dh + (i_bh * NT + i_t - 1) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_dh = tl.make_block_ptr(dh + ((i_tg - 1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + elif STORE_INITIAL_STATE_GRADIENT: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for BV in [32, 64, 128] + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'] +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh_reduction( + g, + gk, + gv, + dh, + doq0, + dh0, + offsets, + chunk_offsets, + T, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_bg = i_nh // NG + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_h = i_hq // NG + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), 
tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_t in range(NT - 1, -1, -1): + if HEAD_FIRST: + p_dh = tl.make_block_ptr(dh + (i_nh * NT + i_t) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + else: + p_dh = tl.make_block_ptr(dh + ((boh+i_t) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dh, boundary_check=(0, 1)).to(tl.float32) + if i_t < NT - 1: + tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min(i_t * BT + BT, T) - 1 + if USE_G: + if HEAD_FIRST: + b_g_last = tl.load(g + i_bg * T + last_idx) + else: + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_dh *= exp(b_g_last) + + if USE_GK: + if HEAD_FIRST: + p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_dh *= exp(b_gk_last)[:, None] + + if USE_GV: + if HEAD_FIRST: + p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_dh *= exp(b_gv_last)[None, :] + + if STORE_INITIAL_STATE_GRADIENT: + p_doq0 = tl.make_block_ptr(doq0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_doq0, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + h0: torch.Tensor, + output_final_state: bool, + states_in_fp32: bool = False, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + # N: the actual number of sequences in the batch with either equal or variable lengths + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + if indices is None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + N, NT = len(offsets) - 1, len(indices) + chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1) + + h = k.new_empty(B, H, NT, K, V, dtype=torch.float) if head_first else k.new_empty(B, NT, H, K, V, dtype=torch.float) + ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None + def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * H) + chunk_fwd_kernel_h_parallel[grid]( + k=k, + v=v, + h=h, + g=g, + gk=gk, + gv=gv, + h0=h0, + ht=ht, + offsets=offsets, + 
indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + kvt, ht = ht, (torch.empty_like(ht) if output_final_state else None) + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) + chunk_fwd_kernel_h_reduction[grid]( + h=h, + g=g, + gk=gk, + gv=gv, + kvt=kvt, + ht=ht, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + h = h.to(k.dtype) if not states_in_fp32 else h + return h, ht + + +def chunk_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + do: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + scale: float, + states_in_fp32: bool = False, + offsets: Optional[torch.Tensor] = None, + indices: Optional[torch.Tensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + HQ = q.shape[1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + # N: the actual number of sequences in the batch with either equal or variable lengths + # NG: number of groups in GQA + if offsets is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + if indices is None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], BT).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + N, NT = len(offsets) - 1, len(indices) + chunk_offsets = torch.cat([offsets.new_tensor([0]), triton.cdiv(offsets[1:] - offsets[:-1], BT)]).cumsum(-1) + NG = HQ // H + + if head_first: + dh = k.new_empty(B, HQ, NT, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + else: + dh = k.new_empty(B, NT, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float) + dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None + + def grid(meta): return (triton.cdiv(K, meta['BK']) * triton.cdiv(V, meta['BV']), NT, B * HQ) + chunk_bwd_kernel_dh_parallel[grid]( + q=q, + g=g, + gk=gk, + gv=gv, + do=do, + dh=dh, + dht=dht, + dh0=dh0, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + + doq0, dh0 = dh0, (torch.empty_like(dh0) if dh0 is not None else None) + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ) + chunk_bwd_kernel_dh_reduction[grid]( + g=g, + gk=gk, + gv=gv, + dh=dh, + doq0=doq0, + dh0=dh0, + offsets=offsets, + chunk_offsets=chunk_offsets, + T=T, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + dh = dh.to(q.dtype) if not states_in_fp32 else dh + return dh, dh0 diff --git a/fla/ops/common/chunk_h_split.py b/fla/ops/common/chunk_h_split.py new file mode 100644 index 0000000000000000000000000000000000000000..cc017fb6a6d058180ee651511d4105d914594494 --- /dev/null +++ b/fla/ops/common/chunk_h_split.py @@ -0,0 +1,677 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op 
import exp + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h_split( + k, + v, + g, + gk, + gv, + hs, + hr, + h0, + ht, + offsets, + split_indices, + T, + S: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # handle one split at a time + # i_h: head index + # i_n: sequence index + # i_s: local split index inside a sequence + i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_ss, i_h = i_sh // H, i_sh % H + if USE_OFFSETS: + i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + else: + NS = tl.cdiv(T, S) + i_n, i_s = i_ss // NS, i_ss % NS + bos, eos = i_n * T, i_n * T + T + i_nh = i_n * H + i_h + + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # for the first split, we directly store the state as the final result + if i_s == 0: + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + p_hr = tl.make_block_ptr(hr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1)) + for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + last_idx = min(i_t * BT + BT, T) - 1 + + # scalar decay + if USE_G: + if HEAD_FIRST: + b_g_last = tl.load(g + i_nh * T + last_idx) + p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT) + p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) + else: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h + b_h *= exp(b_g_last) + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) 
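+ # b_h was just scaled by exp(g_last); discounting v by exp(g_last - g) below means the
+ # k_t v_t^T updates accumulated into b_h are already decayed to the end of this chunk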
+ b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + if HEAD_FIRST: + p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_h *= exp(b_gk_last)[:, None] + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_k = (b_k * exp(b_gk_last[:, None] - b_gk)).to(b_k.dtype) + + # vector decay, h = h @ Diag(gv) + if USE_GV: + if HEAD_FIRST: + p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_h *= exp(b_gv_last)[None, :] + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_v = (b_v * exp(b_gv_last[None, :] - b_gv)).to(b_v.dtype) + + b_h += tl.dot(b_k, b_v) + + # if there are more than one splits, we store the result to (unreduced) hs + # otherwise, we store the result to ht as the final state + if NS > 1: + p_hs = tl.make_block_ptr(hs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_hs, b_h.to(p_hs.dtype.element_ty), boundary_check=(0, 1)) + elif STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_h_reduction( + g, + gk, + gv, + hs, + hr, + ht, + offsets, + split_offsets, + T, + S: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NS = tl.cdiv(T, S) + boh = i_n * NS + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + # skip the first split + for i_s in range(1, NS): + p_hs = tl.make_block_ptr(hs + ((boh + i_s-1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_hr = 
tl.make_block_ptr(hr + ((boh + i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_hr, b_h.to(p_hr.dtype.element_ty), boundary_check=(0, 1)) + + for i_t in range(tl.cdiv(i_s * S, BT), tl.cdiv(min(i_s * S + S, T), BT)): + last_idx = min(i_t * BT + BT, T) - 1 + # scalar decay + if USE_G: + if HEAD_FIRST: + b_g_last = tl.load(g + i_nh * T + last_idx) + else: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + b_h *= exp(b_g_last) + + # vector decay, h = Diag(gk) @ h + if USE_GK: + if HEAD_FIRST: + p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_h *= exp(b_gk_last)[:, None] + + # vector decay, h = h @ Diag(gv) + if USE_GV: + if HEAD_FIRST: + p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) + b_h *= exp(b_gv_last)[None, :] + + if NS > 1: + if STORE_FINAL_STATE: + p_hs = tl.make_block_ptr(hs + ((boh + NS-1) * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_ht = tl.make_block_ptr(ht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h += tl.load(p_hs, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh_split( + q, + g, + gk, + gv, + do, + dht, + dhs, + dhr, + dh0, + offsets, + split_indices, + scale, + T, + S: tl.constexpr, + HQ: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # handle one split at a time + # i_h: head index + # i_n: sequence index + # i_s: local split index inside a sequence + i_k, i_v, i_sh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_ss, i_hq = i_sh // HQ, i_sh % HQ + if USE_OFFSETS: + i_n, i_s = tl.load(split_indices + i_ss * 2).to(tl.int32), tl.load(split_indices + i_ss * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + else: + NS = tl.cdiv(T, S) + i_n, i_s = i_ss // NS, i_ss % NS + bos, eos = i_n * T, i_n * T + T + i_nh = i_n * HQ + i_hq + i_ng, i_h = i_nh // NG, i_hq // NG + + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if i_s == NS - 1: + if 
USE_FINAL_STATE_GRADIENT: + p_dht = tl.make_block_ptr(dht + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dht, boundary_check=(0, 1)).to(tl.float32) + p_dhr = tl.make_block_ptr(dhr + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1)) + + for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1): + if HEAD_FIRST: + p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + last_idx = min(i_t * BT + BT, T) - 1 + if USE_G: + if HEAD_FIRST: + p_g = g + i_ng * T + i_t * BT + tl.arange(0, BT) + p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT) + b_g_last = tl.load(g + i_ng * T + last_idx) + else: + p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.) + b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype) + b_dh *= exp(b_g_last) + + if USE_GK: + if HEAD_FIRST: + p_gk = tl.make_block_ptr(gk + i_ng * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_q = (b_q * exp(b_gk)).to(b_q.dtype) + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_dh *= exp(b_gk_last)[:, None] + + if USE_GV: + if HEAD_FIRST: + p_gv = tl.make_block_ptr(gv + i_ng * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv = tl.load(p_gv, boundary_check=(0, 1)) + b_do = (b_do * exp(b_gv)).to(b_do.dtype) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) 
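+ # walking chunks backwards mirrors the forward recurrence: q/do were rescaled by the
+ # within-chunk gates above, and b_dh is scaled by the chunk's last gate values (g/gk/gv)
+ # before this chunk's q^T @ do contribution is added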
+ b_dh *= exp(b_gv_last)[None, :] + + b_dh += tl.dot(b_q, b_do) + + if NS > 1: + p_dhs = tl.make_block_ptr(dhs + i_sh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dhs, b_dh.to(p_dhs.dtype.element_ty), boundary_check=(0, 1)) + elif STORE_INITIAL_STATE_GRADIENT: + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [32, 64] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['BT', 'USE_G', 'USE_GK', 'USE_GV'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dh_reduction( + g, + gk, + gv, + dhs, + dhr, + dh0, + offsets, + split_offsets, + T, + S: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NG: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hq = i_nh // HQ, i_nh % HQ + i_ng, i_h = i_nh // NG, i_hq // NG + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NS = tl.cdiv(T, S) + boh = tl.load(split_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NS = tl.cdiv(T, S) + boh = i_n * NS + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + for i_s in range(NS - 2, -1, -1): + p_dhs = tl.make_block_ptr(dhs + ((boh+i_s+1) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dhr = tl.make_block_ptr(dhr + ((boh+i_s) * H + i_h) * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dhr, b_dh.to(p_dhr.dtype.element_ty), boundary_check=(0, 1)) + + for i_t in range(tl.cdiv(min(i_s * S + S, T), BT) - 1, tl.cdiv(i_s * S, BT) - 1, -1): + last_idx = min(i_t * BT + BT, T) - 1 + # scalar decay + if USE_G: + if HEAD_FIRST: + b_g_last = tl.load(g + i_ng * T + last_idx) + else: + b_g_last = tl.load(g + (bos + last_idx) * H + i_h) + b_dh *= exp(b_g_last) + + if USE_GK: + if HEAD_FIRST: + p_gk_last = gk + (i_ng * T + last_idx) * K + i_k * BK + tl.arange(0, BK) + p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK) + else: + p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + + b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.) + b_dh *= exp(b_gk_last)[:, None] + + if USE_GV: + if HEAD_FIRST: + p_gv_last = gv + (i_ng * T + last_idx) * V + i_v * BV + tl.arange(0, BV) + p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV) + else: + p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.) 
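+ # the reduction pass never revisits q/do; it only replays each chunk's decay factors so the
+ # per-split dhs blocks produced by the split kernel can be chained backwards across splits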
+ b_dh *= exp(b_gv_last)[None, :] + + if NS > 1: + if STORE_INITIAL_STATE_GRADIENT: + p_dhs = tl.make_block_ptr(dhs + (boh * H + i_h)*K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_dh0 = tl.make_block_ptr(dh0 + i_nh * K*V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh += tl.load(p_dhs, boundary_check=(0, 1)).to(tl.float32) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_h( + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + h0: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + split_offsets: Optional[torch.LongTensor] = None, + split_indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64, + split_size: int = 256, + states_in_fp32: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + # B: batch size + # N: the actual number of sequences in the batch + # H: number of heads + # T: sequence length, can be variable across sequences + # S: split size, a multiple of chunk size + # BT: chunk size + S, BT = split_size, chunk_size + assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple of `chunk_size` {BT}" + if offsets is None: + N = B + NS = N * triton.cdiv(T, S) + else: + N = len(offsets) - 1 + NS = split_offsets[-1] + + # unreduced kv states per split + hs = k.new_empty(NS, H, K, V, dtype=torch.float) + # reduced states per split + hr = k.new_empty(NS, H, K, V, dtype=torch.float if states_in_fp32 else k.dtype) + ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None + # parallelized over splits + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * H) + chunk_fwd_kernel_h_split[grid]( + k=k, + v=v, + g=g, + gk=gk, + gv=gv, + hs=hs, + hr=hr, + h0=h0, + ht=ht, + offsets=offsets, + split_indices=split_indices, + T=T, + S=S, + H=H, + K=K, + V=V, + BT=BT, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H) + chunk_fwd_kernel_h_reduction[grid]( + g=g, + gk=gk, + gv=gv, + hs=hs, + hr=hr, + ht=ht, + offsets=offsets, + split_offsets=split_offsets, + T=T, + S=S, + H=H, + K=K, + V=V, + BT=BT, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + return hr, ht + + +def chunk_bwd_dh( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + gk: torch.Tensor, + gv: torch.Tensor, + do: torch.Tensor, + h0: torch.Tensor, + dht: torch.Tensor, + scale: float, + offsets: Optional[torch.Tensor] = None, + split_offsets: Optional[torch.Tensor] = None, + split_indices: Optional[torch.Tensor] = None, + head_first: bool = True, + chunk_size: int = 64, + split_size: int = 256, + states_in_fp32: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + HQ = q.shape[1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + # B: batch size + # N: the actual number of sequences in the batch + # H: number of heads + # T: sequence length, can be variable across sequences + # S: split size, a multiple of chunk size + # BT: chunk size + S, BT = max(chunk_size, min(split_size, triton.next_power_of_2(T))), chunk_size + assert S % BT == 0, f"The `split_size` (got {S}) must be a multiple 
of `chunk_size` {BT}" + if offsets is None: + N = B + NS = N * triton.cdiv(T, S) + else: + N = len(offsets) - 1 + NS = split_offsets[-1] + # number of groups in GQA + NG = HQ // H + + dhs = q.new_empty(NS, HQ, K, V, dtype=torch.float) + dhr = q.new_empty(NS, HQ, K, V, dtype=torch.float if states_in_fp32 else k.dtype) + dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None + + # parallelized over splits + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), NS * HQ) + chunk_bwd_kernel_dh_split[grid]( + q=q, + g=g, + gk=gk, + gv=gv, + do=do, + dht=dht, + dhs=dhs, + dhr=dhr, + dh0=dh0, + offsets=offsets, + split_indices=split_indices, + scale=scale, + T=T, + S=S, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first, + ) + + def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * HQ) + chunk_bwd_kernel_dh_reduction[grid]( + g=g, + gk=gk, + gv=gv, + dhs=dhs, + dhr=dhr, + dh0=dh0, + offsets=offsets, + split_offsets=split_offsets, + T=T, + S=S, + HQ=HQ, + H=H, + K=K, + V=V, + BT=BT, + NG=NG, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + HEAD_FIRST=head_first + ) + return dhr, dh0 diff --git a/fla/ops/common/chunk_o.py b/fla/ops/common/chunk_o.py new file mode 100644 index 0000000000000000000000000000000000000000..b1e99d1d28bebc49994deaef04c252be74b2d570 --- /dev/null +++ b/fla/ops/common/chunk_o.py @@ -0,0 +1,668 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp, safe_exp +from fla.utils import check_shared_mem, is_nvidia_hopper + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + s_qk = K if HEAD_FIRST else H*K + s_vo = V if HEAD_FIRST else H*V + s_g = 1 if HEAD_FIRST else H + # offset calculation + q += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + k += (i_bh * T*K) if HEAD_FIRST else ((bos * H + i_h) * K) + v += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V) + o += (i_bh * T*V) if HEAD_FIRST else ((bos * H + i_h) * V) + h += ((i_bh * NT + i_t).to(tl.int64) * K*V) if HEAD_FIRST else ((i_tg * H + i_h).to(tl.int64) * 
K*V) + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_i = tl.arange(0, BT) + m_A = o_i[:, None] >= o_i[None, :] + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'USE_G': lambda args: args['g'] is not None, + 'USE_DW': lambda args: args['dw'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G', 'USE_DW'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dqkwg( + q, + k, + v, + h, + g, + do, + dh, + dq, + dk, + dg, + w, + dv, + dw, + offsets, + indices, + scale, + B: tl.constexpr, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_DW: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_G: + dg += i_k * B * H * T + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + v += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V + do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V + h += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V + dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V + q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + dq += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + dk += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + s_qk = K if HEAD_FIRST else H*K + s_vo = V if HEAD_FIRST else H*V + s_g = 1 if HEAD_FIRST else H + + # for delta rule only + if USE_DW: + dw += i_bh * T*K if HEAD_FIRST else (bos * H + 
i_h) * K + dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V + w += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_ds = tl.zeros([BT, BT], dtype=tl.float32) + b_dg_last = tl.zeros([1,], dtype=tl.float32) if USE_G else None + b_dw = tl.zeros([BT, BK], dtype=tl.float32) if USE_DW else None + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + if USE_G: + b_dg_last += (tl.sum(b_h * b_dh)) + # [BT, BV] @ [BV, BT] -> [BT, BT] + b_ds += tl.dot(b_do, tl.trans(b_v)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + if USE_DW: + p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + if USE_DW and not USE_G: + p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + tl.debug_barrier() + o_i = tl.arange(0, BT) + p_q = tl.make_block_ptr(q, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + p_dq = tl.make_block_ptr(dq, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + if USE_G: + b_dg = tl.zeros([BT,], dtype=tl.float32) + g += i_bh * T if HEAD_FIRST else bos * H + i_h + dg += i_bh * T if HEAD_FIRST else bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g) + b_dg_last *= exp(b_g_last) + + if USE_DW: + p_w = tl.make_block_ptr(w, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_dw = b_dw * exp(b_g)[:, None] + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + b_dg -= tl.sum(b_w * b_dw, axis=1) + + b_dq = b_dq * exp(b_g)[:, None] * scale + b_dg += tl.sum(b_dq * b_q, axis=1) + + b_dk = b_dk * safe_exp(-b_g + b_g_last)[:, None] + b_dg -= tl.sum(b_k * b_dk, axis=1) + b_dg_last += tl.sum(b_dk * b_k) + + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * safe_exp(b_g[:, None] - b_g[None, :]), 0) * scale + b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k)) + b_dg += tl.sum(b_ds2, axis=1) + b_dg -= tl.sum(b_ds2, axis=0) + + b_ds = b_ds.to(b_k.dtype) + # [BT, BK] + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) + p_dg = tl.make_block_ptr(dg, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + # (SY 09/21) revcumsum in a separate kernel due to 
strange triton compiler issue + # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last) + b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + else: + b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds, 0) + b_ds = b_ds.to(b_k.dtype) + b_dq += tl.dot(b_ds, b_k) + b_dk += tl.dot(tl.trans(b_ds), b_q) * scale + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'USE_G': lambda args: args['g'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dv( + q, + k, + g, + do, + dv, + dh, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + # offset calculation + q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V + dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V + s_qk = K if HEAD_FIRST else H*K + s_vo = V if HEAD_FIRST else H*V + s_g = 1 if HEAD_FIRST else H + dh += (i_bh * NT + i_t).to(tl.int64) * K*V if HEAD_FIRST else (i_tg * H + i_h).to(tl.int64) * K*V + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, b_q) + p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) + + if USE_G: + g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * s_g) + b_dv *= safe_exp(-b_g + b_g_last)[:, None] + + mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) + if USE_G: + b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + else: + b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty) + p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + 
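+ # b_dv already holds the inter-chunk term k @ dh (decayed towards the chunk end when gated);
+ # the masked intra-chunk matrix b_A (k q^T on the upper triangle, with scale/gates folded in)
+ # is applied to do and accumulated before the single store below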
p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv += tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_G': lambda args: args['g'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'USE_G'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_bwd_kernel_dv_local( + q, + k, + g, + do, + dv, + offsets, + indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + k += i_bh * T*K if HEAD_FIRST else (bos * H + i_h) * K + do += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V + dv += i_bh * T*V if HEAD_FIRST else (bos * H + i_h) * V + s_qk = K if HEAD_FIRST else H*K + s_vo = V if HEAD_FIRST else H*V + s_g = 1 if HEAD_FIRST else H + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k, (T, K), (s_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_q = tl.make_block_ptr(q, (K, T), (1, s_qk), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, b_q) + + if USE_G: + g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + p_g = tl.make_block_ptr(g, (T,), (s_g,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + + mask = (tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :]) + if USE_G: + b_A = tl.where(mask, b_A * safe_exp(b_g[None, :] - b_g[:, None]) * scale, 0).to(do.dtype.element_ty) + else: + b_A = tl.where(mask, b_A * scale, 0).to(do.dtype.element_ty) + + for i_v in range(tl.cdiv(V, BV)): + p_do = tl.make_block_ptr(do, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (s_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *q.shape, v.shape[-1] + else: + B, T, H, K, V = *q.shape, v.shape[-1] + if scale is None: + scale = k.shape[-1] ** -0.5 + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = torch.empty_like(v) + + def grid(meta): return (triton.cdiv(V, meta['BV']), NT, B * H) + 
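+ # one program per (value block, chunk, batch*head); BK/BV and warp/stage counts are left to the autotuner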
chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return o + + +def chunk_bwd_dv( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *k.shape, do.shape[-1] + else: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + # H100 can have larger block size + if check_shared_mem('hopper', k.device.index): + CONST_TILING = 128 + elif check_shared_mem: + CONST_TILING = 64 + else: + CONST_TILING = 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + NV = triton.cdiv(V, BV) + + dv = torch.empty_like(do) + grid = (NV, NT, B * H) + chunk_bwd_kernel_dv[grid]( + q, + k, + g, + do, + dv, + dh, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dv + + +def chunk_bwd_dv_local( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *k.shape, do.shape[-1] + else: + B, T, H, K, V = *k.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + # H100 can have larger block size + if check_shared_mem('hopper', k.device.index): + CONST_TILING = 128 + elif check_shared_mem: + CONST_TILING = 64 + else: + CONST_TILING = 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dv = torch.empty_like(do) + grid = (NT, B * H) + chunk_bwd_kernel_dv_local[grid]( + q, + k, + g, + do, + dv, + offsets, + indices, + scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + return dv + + +def chunk_bwd_dqkwg( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + h: torch.Tensor, + dh: torch.Tensor, + dv: Optional[torch.Tensor] = None, + w: Optional[torch.Tensor] = None, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + scale: float = 1.0, + head_first: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + CONST_TILING = 64 if check_shared_mem() else 32 + BK = min(triton.next_power_of_2(K), CONST_TILING) + BV = min(triton.next_power_of_2(V), CONST_TILING) + NK = triton.cdiv(K, BK) + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dg = torch.empty(NK, *g.shape, dtype=torch.float32, device=g.device) if g is not None else None + dw = torch.empty_like(w) if w is not None else None + + grid = (NK, NT, B * H) + chunk_bwd_kernel_dqkwg[grid]( + q=q, + k=k, + v=v, + h=h, + g=g, + do=do, + dh=dh, + dv=dv, + w=w, + dw=dw, + dq=dq, + dk=dk, + dg=dg, + offsets=offsets, + 
indices=indices, + scale=scale, + B=B, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + + if dg is not None: + dg = dg.sum(0) + return dq, dk, dw, dg diff --git a/fla/ops/common/chunk_scaled_dot_kkt.py b/fla/ops/common/chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000000000000000000000000000000000..ff30664dce50a8869dd6198aaecea2ab6a171704 --- /dev/null +++ b/fla/ops/common/chunk_scaled_dot_kkt.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.common.utils import prepare_chunk_indices + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['H', 'K', 'BT', 'USE_OFFSETS'], +) +@triton.jit(do_not_specialize=['T']) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + A, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = tl.arange(0, BT) + + if HEAD_FIRST: + p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + if HEAD_FIRST: + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + else: + p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0) + if HEAD_FIRST: + p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (BT*H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], + head_first: bool = False, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]` if not `head_first` else `[B, H, T, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]` if not `head_first` else `[B, H, T]`. + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + head_first (bool): + If False, the input/output tensor is in the shape of `[B, T, H, K]`. + If True, the input/output tensor is in the shape of `[B, H, T, K]`. + Default: False + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. 
Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` if not `head_first` else `[B, H, T, BT]`, + where `BT` is the chunk size. + """ + if head_first: + B, H, T, K = k.shape + else: + B, T, H, K = k.shape + BT = chunk_size + indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(indices) + A = torch.empty(B, *((H, T) if head_first else (T, H)), BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + A=A, + offsets=cu_seqlens, + indices=indices, + T=T, + H=H, + K=K, + BT=BT, + HEAD_FIRST=head_first + ) + return A diff --git a/fla/ops/common/fused_recurrent.py b/fla/ops/common/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..263de38d060716ec525a273d45eb1c3fe08ac4be --- /dev/null +++ b/fla/ops/common/fused_recurrent.py @@ -0,0 +1,575 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import chunk_global_cumsum +from fla.ops.utils.op import exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_FINAL_STATE': lambda args: args['ht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=["BK", "BV", "USE_GK", "USE_GV", "USE_G"], +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_fwd_kernel( + q, + k, + v, + g, + gk, + gv, + o, + h0, + ht, + offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + # indices + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if HEAD_FIRST: + p_q = q + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_o = o + (i_k * B*H + i_nh) * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + if USE_G: + p_g = g + i_nh * T + ((T-1) if REVERSE else 0) + if USE_GK: + p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + else: + p_q = q + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_o = o + ((i_k * all + bos) + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + if USE_G: + p_g = g + (bos + ((T-1) if REVERSE else 
0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + mask_k = (i_k * BK + tl.arange(0, BK)) < K + mask_v = (i_v * BV + tl.arange(0, BV)) < V + mask_h = mask_k[None, :] & mask_v[:, None] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_h = b_h * exp(b_gk[None, :]) + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_h = b_h * exp(b_gv[:, None]) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_h = b_h * exp(b_g) + b_h += b_k[None, :] * b_v[:, None] + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + p_q += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_o += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + if USE_GK: + p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + if USE_GV: + p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + if USE_G: + p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) + + if STORE_FINAL_STATE: + p_ht = ht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +@triton.heuristics({ + 'USE_INITIAL_STATE': lambda args: args['h0'] is not None, + 'STORE_INITIAL_STATE_GRADIENT': lambda args: args['dh0'] is not None, + 'USE_FINAL_STATE_GRADIENT': lambda args: args['dht'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BK', 'BV', 'USE_GK', 'USE_GV', 'USE_G'], +) +@triton.jit(do_not_specialize=['T']) +def fused_recurrent_bwd_kernel( + q, + k, + v, + g, + gk, + gv, + h0, + do, + dq, + dk, + dv, + dht, + dh0, + offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + REVERSE: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_GV: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_INITIAL_STATE_GRADIENT: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_v, i_k, i_nh = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64), tl.program_id(2).to(tl.int64) + i_n, i_h = i_nh // H, i_nh % H + if USE_OFFSETS: + bos, eos = tl.load(offsets + i_n).to(tl.int64), tl.load(offsets + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if HEAD_FIRST: + p_k = k + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_v = v + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_do = do + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_dq 
= dq + (i_v * B*H + i_nh) * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + if USE_G: + p_g = g + i_nh * T + ((T-1) if REVERSE else 0) + if USE_GK: + p_gk = gk + i_nh * T*K + ((T-1) * K if REVERSE else 0) + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + i_nh * T*V + ((T-1) * V if REVERSE else 0) + i_v * BV + tl.arange(0, BV) + else: + p_k = k + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_do = do + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dq = dq + ((i_v * all + bos) + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_G: + p_g = g + (bos + ((T-1) if REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T-1) if REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T-1) if REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + mask_k = i_k * BK + tl.arange(0, BK) < K + mask_v = i_v * BV + tl.arange(0, BV) < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h0 = h0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_h = b_h * exp(b_g) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_h = b_h * exp(b_gk[:, None]) + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_h = b_h * exp(b_gv[None, :]) + b_h += b_k[:, None] * b_v[None, :] + b_dq = b_h * b_do[None, :] + b_dq = tl.sum(b_dq, axis=1) * scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), mask=mask_k) + + p_k += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + p_v += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_do += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + p_dq += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + if USE_G: + p_g += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) + if USE_GK: + p_gk += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * K + if USE_GV: + p_gv += (-1 if REVERSE else 1) * (1 if HEAD_FIRST else H) * V + + # sync threads + tl.debug_barrier() + + if HEAD_FIRST: + p_q = q + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_k = k + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_v = v + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_do = do + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) + p_dk = dk + (i_v * B*H + i_nh) * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) + p_dv = dv + (i_k * B*H + i_nh) * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) + if USE_G: + p_g = g + i_nh * T + ((T - 1) if not REVERSE else 0) + if USE_GK: + p_gk = gk + i_nh * T*K + ((T - 1) * K if not REVERSE else 0) + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + i_nh * T*V + ((T - 1) * V if not REVERSE else 0) + i_v * BV + tl.arange(0, BV) + else: + p_q = q + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + 
tl.arange(0, BK) + p_k = k + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_v = v + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_do = do + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + p_dk = dk + ((i_v * all + bos) + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + p_dv = dv + ((i_k * all + bos) + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + if USE_G: + p_g = g + (bos + ((T - 1) if not REVERSE else 0)) * H + i_h + if USE_GK: + p_gk = gk + (bos + ((T - 1) if not REVERSE else 0)) * H*K + i_h * K + i_k * BK + tl.arange(0, BK) + if USE_GV: + p_gv = gv + (bos + ((T - 1) if not REVERSE else 0)) * H*V + i_h * V + i_v * BV + tl.arange(0, BV) + + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + if USE_FINAL_STATE_GRADIENT: + p_dht = dht + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_dh += tl.load(p_dht, mask=mask_h, other=0).to(tl.float32) + + for _ in range(T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_v, other=0).to(tl.float32) + b_dh += b_q[:, None] * b_do[None, :] + b_dk = tl.sum(b_dh * b_v[None, :], axis=1) + b_dv = tl.sum(b_dh * b_k[:, None], axis=0) + if USE_G: + b_g = tl.load(p_g).to(tl.float32) + b_dh *= exp(b_g) + if USE_GK: + b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32) + b_dh *= exp(b_gk)[:, None] + if USE_GV: + b_gv = tl.load(p_gv, mask=mask_v, other=0).to(tl.float32) + b_dh *= exp(b_gv)[None, :] + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), mask=mask_k) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), mask=mask_v) + + p_q += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K + p_k += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K + p_v += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V + p_do += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V + p_dk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K + p_dv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V + if USE_G: + p_g += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) + if USE_GK: + p_gk += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * K + if USE_GV: + p_gv += (1 if REVERSE else -1) * (1 if HEAD_FIRST else H) * V + + if STORE_INITIAL_STATE_GRADIENT: + p_dh0 = dh0 + i_nh * K*V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + tl.store(p_dh0, b_dh.to(p_dh0.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + h0 = initial_state + if output_final_state: + ht = q.new_empty(N, H, K, V, dtype=torch.float32) + else: + ht = None + o = q.new_empty(NK, *v.shape, dtype=torch.float32) + + grid = (NV, NK, N 
* H) + fused_recurrent_fwd_kernel[grid]( + q, + k, + v, + g, + gk, + gv, + o, + h0, + ht, + offsets, + scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + REVERSE=reverse, + HEAD_FIRST=head_first + ) + o = o.sum(0) + return o, ht + + +def fused_recurrent_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + o: Optional[torch.Tensor] = None, + do: Optional[torch.Tensor] = None, + dht: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + N = B if offsets is None else len(offsets) - 1 + + BK, BV = min(K, 64), min(V, 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + + dq = q.new_empty(NV, *q.shape, dtype=torch.float32) + dk = q.new_empty(NV, *k.shape, dtype=torch.float32) + dv = q.new_empty(NK, *v.shape, dtype=torch.float32) + h0 = initial_state + dh0 = torch.empty_like(initial_state) if initial_state is not None else None + + grid = (NV, NK, N * H) + fused_recurrent_bwd_kernel[grid]( + q, + k, + v, + g, + gk, + gv, + h0, + do, + dq, + dk, + dv, + dht, + dh0, + offsets, + scale, + B=B, + T=T, + H=H, + K=K, + V=V, + BK=BK, + BV=BV, + USE_G=g is not None, + USE_GK=gk is not None, + USE_GV=gv is not None, + REVERSE=reverse, + HEAD_FIRST=head_first + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dg, dgk, dgv = None, None, None + if g is not None: + dg = chunk_global_cumsum( + (dq * q.float() - dk * k.float()).sum(-1), + reverse=not reverse, + offsets=offsets, + head_first=head_first + ) + if gk is not None: + dgk = chunk_global_cumsum( + dq * q.float() - dk * k.float(), + reverse=not reverse, + offsets=offsets, + head_first=head_first + ) + if gv is not None: + dgv = chunk_global_cumsum( + do.float() * o.float() - dv * v.float(), + reverse=not reverse, + offsets=offsets, + head_first=head_first + ) + + return dq, dk, dv, dg, dgk, dgv, dh0 + + +class FusedRecurrentFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + offsets: Optional[torch.LongTensor] = None, + head_first: bool = True + ): + o, ht = fused_recurrent_fwd( + q=q, + k=k, + v=v, + g=g, + gk=gk, + gv=gv, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + offsets=offsets, + head_first=head_first + ) + ctx.save_for_backward(q, k, v, g, gk, gv, initial_state, o) + ctx.scale = scale + ctx.reverse = reverse + ctx.offsets = offsets + ctx.head_first = head_first + return o.to(q.dtype), ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + q, k, v, g, gk, gv, initial_state, o = ctx.saved_tensors + # not supported yet. 
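+ # Combining a non-zero final-state gradient with trainable (log-space) gates is
+ # not supported yet; the checks below make that explicit instead of returning
+ # silently incorrect gate gradients.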
+ if dht is not None: + if not dht.eq(0).all(): + if g is not None: + assert g.requires_grad is False, "Cannot load final state gradient and use gates at the same time" + if gk is not None: + assert gk.requires_grad is False, "Cannot load final state gradient and use gates at the same time" + if gv is not None: + assert gv.requires_grad is False, "Cannot load final state gradient and use gates at the same time" + dq, dk, dv, dg, dgk, dgv, dh0 = fused_recurrent_bwd( + q=q, + k=k, + v=v, + g=g, + gk=gk, + gv=gv, + o=o, + do=do, + dht=dht, + scale=ctx.scale, + initial_state=initial_state, + reverse=ctx.reverse, + offsets=ctx.offsets, + head_first=ctx.head_first + ) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, dgk, dgv, None, dh0, None, None, None, None + + +def fused_recurrent( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + gk: Optional[torch.Tensor] = None, + gv: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +): + if scale is None: + scale = k.shape[-1] ** -0.5 + return FusedRecurrentFunction.apply( + q, + k, + v, + g, + gk, + gv, + scale, + initial_state, + output_final_state, + reverse, + cu_seqlens, + head_first + ) diff --git a/fla/ops/common/utils.py b/fla/ops/common/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c61cf9a36b4a79578f8692070bce68a6d39830b8 --- /dev/null +++ b/fla/ops/common/utils.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from fla.utils import tensor_cache + + +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [4, 8, 16, 32] + ], + key=['B'], +) +@triton.jit +def prepare_position_ids_kernel( + y, + offsets, + B: tl.constexpr +): + i_n = tl.program_id(0) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + + o = tl.arange(0, B) + for i in range(0, tl.cdiv(T, B) * B, B): + o_i = o + i + tl.store(y + bos + o_i, o_i, o_i < T) + + +@tensor_cache +def prepare_lens(offsets: torch.LongTensor) -> torch.LongTensor: + return offsets[1:] - offsets[:-1] + + +@tensor_cache +def prepare_position_ids(offsets: torch.LongTensor) -> torch.LongTensor: + return torch.cat([torch.arange(n, dtype=offsets.dtype, device=offsets.device) for n in prepare_lens(offsets).unbind()]) + + +@tensor_cache +def prepare_sequence_ids(position_ids: torch.LongTensor) -> torch.LongTensor: + return position_ids.eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(offsets: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(offsets) + return torch.stack([prepare_sequence_ids(position_ids), position_ids], 1).to(offsets) + + +@tensor_cache +def prepare_chunk_indices( + offsets: torch.LongTensor, + chunk_size: int +) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(offsets), chunk_size).tolist()]) + return torch.stack([prepare_sequence_ids(indices), indices], 1).to(offsets) + + +@tensor_cache +def prepare_chunk_offsets( + offsets: torch.LongTensor, + chunk_size: int +) -> torch.LongTensor: + return torch.cat([offsets.new_tensor([0]), triton.cdiv(prepare_lens(offsets), chunk_size)]).cumsum(-1) diff --git 
a/fla/ops/forgetting_attn/__init__.py b/fla/ops/forgetting_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e62c741d464f01b5c0c6707671061293b9d48644 --- /dev/null +++ b/fla/ops/forgetting_attn/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from .parallel import parallel_forgetting_attn + +__all__ = [ + 'parallel_forgetting_attn' +] diff --git a/fla/ops/forgetting_attn/__pycache__/__init__.cpython-311.pyc b/fla/ops/forgetting_attn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2563167f8890b52f83041f1e80dfceff02340f07 Binary files /dev/null and b/fla/ops/forgetting_attn/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/forgetting_attn/__pycache__/parallel.cpython-311.pyc b/fla/ops/forgetting_attn/__pycache__/parallel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d93808f1ecdc3042d380c9e43516b78e6a399e28 Binary files /dev/null and b/fla/ops/forgetting_attn/__pycache__/parallel.cpython-311.pyc differ diff --git a/fla/ops/forgetting_attn/parallel.py b/fla/ops/forgetting_attn/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..88fea7f29b238e488848711ed894cb6cae7ea91b --- /dev/null +++ b/fla/ops/forgetting_attn/parallel.py @@ -0,0 +1,708 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +import triton +import triton.language as tl +from einops import rearrange, reduce + +from fla.ops.common.utils import prepare_chunk_indices +from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum +from fla.ops.utils.op import div, exp, log +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4, 5] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit +def parallel_forgetting_attn_fwd_kernel( + q, + k, + v, + g, + o, + lse, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT,] + b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + # [BT, BV] + 
b_o = tl.zeros([BT, BV], dtype=tl.float32) + + b_m = tl.full([BT], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([BT], dtype=tl.float32) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS,] + b_gk = tl.load(p_gk, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_gq[:, None] - b_gk[None, :] + b_s = tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf')) + + # [BT] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = exp(b_s - b_m[:, None]) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + + for i_s in range(i_t * BT - BS, -BS, -BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS,] + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + + b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32) + b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0. 
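+ # b_gn/b_gp are the chunk-local cumulative gates at the last position of this key
+ # block and just before it (0 when the block starts a new chunk); b_gq is shifted
+ # by their difference right after the score computation so the query-side decay
+ # stays consistent while scanning backwards over earlier key blocks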
+ # [BT, BS] + b_s = tl.dot(b_q, b_k) + b_gq[:, None] + (b_gn - b_gk)[None, :] + + b_gq += b_gn - b_gp + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [BT, BS] + b_p = exp(b_s - b_m[:, None]) + # [BT] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [BT, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + + b_o = div(b_o, b_acc[:, None]) + b_m += log(b_acc) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse, b_m.to(p_lse.dtype.element_ty), boundary_check=(0,)) + + +@triton.jit +def parallel_forgetting_attn_bwd_kernel_preprocess( + o, + do, + delta, + B: tl.constexpr, + V: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4] + ([8] if check_shared_mem('hopper') else []) + for num_stages in [2, 3, 4] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_forgetting_attn_bwd_kernel_dq( + q, + k, + v, + g, + lse, + delta, + do, + dq, + dg, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_dq = tl.make_block_ptr(dq + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * HQ + i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BT] + b_gq = tl.load(p_g, boundary_check=(0,)).to(tl.float32) + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + # [BT, BK] + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + # [BT] + b_dg = tl.zeros([BT,], dtype=tl.float32) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + p_gk = 
tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_k = i_s + tl.arange(0, BS) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS,] + b_gk = tl.load(p_gk, boundary_check=(0,)) + # [BT, BS] + b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] - b_gk[None, :] + b_p = exp(tl.where(o_q[:, None] >= o_k[None, :], b_s, float('-inf'))) + + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + # [BT] + b_dg += tl.sum(b_ds, 1) + + for i_s in range(i_t * BT - BS, -BS, -BS): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BS,] + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + + b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32) + b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0. + # [BT, BS] + b_s = tl.dot(b_q, b_k) + (b_gq - b_lse)[:, None] + (b_gn - b_gk)[None, :] + b_p = exp(b_s) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp - b_delta[:, None]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + # [BT] + b_dg += tl.sum(b_ds, 1) + + b_gq += b_gn - b_gp + + b_dq *= scale + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=['B', 'H', 'G', 'K', 'V', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_forgetting_attn_bwd_kernel_dkv( + q, + k, + v, + g, + lse, + delta, + do, + dk, + dv, + dg, + offsets, + indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_hq = i_bh // HQ, i_bh % HQ + i_h = i_hq // G + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + i_n = i_b + bos, eos = i_n * T, i_n * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_gk = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + p_dk = tl.make_block_ptr(dk + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_t * BT, 0), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * HQ 
+ i_hq), (T,), (HQ,), (i_t * BT,), (BT,), (0,)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + # [BT] + b_gk = tl.load(p_gk, boundary_check=(0,)).to(tl.float32) + b_dg = tl.zeros([BT,], dtype=tl.float32) + + o_k = i_t * BT + tl.arange(0, BT) + m_k = o_k < T + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32) + + m_q = o_q < T + m_s = (o_k[:, None] <= o_q[None, :]) & m_k[:, None] & m_q[None, :] + # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) - b_gk[:, None] + (b_gq - b_lse)[None, :] + b_p = tl.where(m_s, exp(b_s), 0) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + # [BT] + b_dg -= tl.sum(b_ds, 1) + + b_gk -= tl.load(g + (bos + min((i_t + 1) * BT, T) - 1) * HQ + i_hq).to(tl.float32) + for i_s in range((i_t + 1) * BT, T, BS): + p_q = tl.make_block_ptr(q + (bos * HQ + i_hq) * K, (T, K), (HQ*K, 1), (i_s, 0), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do + (bos * HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + p_lse = tl.make_block_ptr(lse + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_delta = tl.make_block_ptr(delta + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + p_gq = tl.make_block_ptr(g + bos * HQ + i_hq, (T,), (HQ,), (i_s,), (BS,), (0,)) + + # [BS] + o_q = i_s + tl.arange(0, BS) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_lse = tl.load(p_lse, boundary_check=(0,)) + b_delta = tl.load(p_delta, boundary_check=(0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32) + + b_gn = tl.load(g + (bos + min(i_s + BS, T) - 1) * HQ + i_hq).to(tl.float32) + b_gp = tl.load(g + (bos + i_s - 1) * HQ + i_hq).to(tl.float32) if i_s % BT > 0 else 0. 
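+ # b_gn/b_gp bracket the chunk-local gate cumsum of this (later) query block;
+ # b_gk is shifted by their difference at the end of the iteration so the
+ # key-side decay stays aligned while scanning forward over later query blocks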
+ # [BT, BS] + b_s = tl.dot(b_k, tl.trans(b_q)) - (b_gk + b_gp)[:, None] + (b_gq - b_lse)[None, :] + b_p = exp(b_s) + # [BT, BS] @ [BS, BV] -> [BT, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BT, BV] @ [BV, BS] -> [BT, BS] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BT, BS] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BT, BS] @ [BS, BK] -> [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + # [BT] + b_dg -= tl.sum(b_ds, 1) + + b_gk -= b_gn - b_gp + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +def parallel_forgetting_attn_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: float, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + BK = max(16, triton.next_power_of_2(K)) + assert V <= 256, "V must be less than or equal to 256" + if check_shared_mem('hopper'): + BS = min(64, max(16, triton.next_power_of_2(T))) + else: + BS = min(32, max(16, triton.next_power_of_2(T))) + BV = min(256, max(16, triton.next_power_of_2(V))) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + grid = (NV, NT, B * HQ) + parallel_forgetting_attn_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + o=o, + lse=lse, + scale=scale, + offsets=offsets, + indices=indices, + B=B, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_forgetting_attn_bwd_preprocess( + o: torch.Tensor, + do: torch.Tensor +): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float) + parallel_forgetting_attn_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_forgetting_attn_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + scale: float = None, + chunk_size: int = 128, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + HQ = q.shape[2] + G = HQ // H + BT = chunk_size + BS = min(32, max(16, triton.next_power_of_2(T))) + BK = max(16, triton.next_power_of_2(K)) + BV = max(16, triton.next_power_of_2(V)) + NV = triton.cdiv(V, BV) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + delta = parallel_forgetting_attn_bwd_preprocess(o, do) + dq = q.new_empty(B, T, HQ, K, dtype=q.dtype) + dk = q.new_empty(B, T, HQ, K, dtype=k.dtype if H == HQ else torch.float) + dv = q.new_empty(B, T, HQ, V, dtype=v.dtype if H == HQ else torch.float) + dg = q.new_empty(g.shape, dtype=torch.float) + # NOTE: the original `dg` can be destroyed during autotuning + # this is [a known triton issue](https://github.com/triton-lang/triton/issues/5082), which will be fixed in 3.3 (?) 
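+ # (autotuning benchmarks every candidate config on the same buffers, so letting the
+ # second kernel accumulate into the `dg` produced by the first one is not safe)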
+ # so we need to make a copy of `dg` + dg2 = q.new_empty(g.shape, dtype=torch.float) + grid = (NV, NT, B * HQ) + parallel_forgetting_attn_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + g=g, + lse=lse, + delta=delta, + do=do, + dq=dq, + dg=dg, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + parallel_forgetting_attn_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + g=g, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + dg=dg2, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV + ) + dk = reduce(dk, 'b t (h g) k -> b t h k', g=G, reduction='sum') + dv = reduce(dv, 'b t (h g) v -> b t h v', g=G, reduction='sum') + dg = dg.add_(dg2) + return dq, dk, dv, dg + + +@torch.compile +class ParallelForgettingAttentionFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, g, scale, offsets): + ctx.dtype = q.dtype + if check_shared_mem('hopper'): + chunk_size = min(128, max(16, triton.next_power_of_2(q.shape[1]))) + else: + chunk_size = min(64, max(16, triton.next_power_of_2(q.shape[1]))) + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None + + g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=False) + o, lse = parallel_forgetting_attn_fwd( + q=q, + k=k, + v=v, + g=g, + scale=scale, + chunk_size=chunk_size, + offsets=offsets, + indices=indices + ) + ctx.save_for_backward(q, k, v, g, o, lse) + ctx.chunk_size = chunk_size + ctx.offsets = offsets + ctx.indices = indices + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do): + q, k, v, g, o, lse = ctx.saved_tensors + dq, dk, dv, dg = parallel_forgetting_attn_bwd( + q=q, + k=k, + v=v, + g=g, + o=o, + lse=lse, + do=do, + scale=ctx.scale, + chunk_size=ctx.chunk_size, + offsets=ctx.offsets, + indices=ctx.indices + ) + dg = chunk_global_cumsum(dg, reverse=True, head_first=False, offsets=ctx.offsets) + return dq.to(q), dk.to(k), dv.to(v), dg.to(g), None, None, None, None, None, None, None, None + + +def parallel_forgetting_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA will be applied if HQ is divisible by H. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + Forget gates (in **log space**) of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. 
+ head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if g is not None: + g = g.float() + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + g = rearrange(g, 'b h t -> b t h') + o = ParallelForgettingAttentionFunction.apply(q, k, v, g, scale, cu_seqlens) + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o diff --git a/fla/ops/generalized_delta_rule/dplr/__init__.py b/fla/ops/generalized_delta_rule/dplr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6de2928ca88abc25dc3156c4dc4fcb13ace180d --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/__init__.py @@ -0,0 +1,7 @@ +from .chunk import chunk_dplr_delta_rule +from .fused_recurrent import fused_recurrent_dplr_delta_rule + +__all__ = [ + 'chunk_dplr_delta_rule', + 'fused_recurrent_dplr_delta_rule' +] diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..496f658775f0da48d1c4fedabb18e3e58fe833c4 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fba08c276c30d9b31682563273a9bc0cc0595c1a Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7260f407aaf634c543a737572bace2b89ddef595 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_bwd.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a1b87474e3e488592b4aff11d90b352f516b1f5 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_A_fwd.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa6a2aea69c84179e0fefa0d4cfb3c9426211821 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_bwd.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae0edfd522112e34e5f728f668132f89ea0f7009 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_h_fwd.cpython-311.pyc differ diff --git 
a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c849169451c3d9a1e0d6bf516d305630633c1de4 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_bwd.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d43bbb569cec777d6c07e53829029ccf7be4ee48 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/chunk_o_fwd.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c8fd75c3634048df543a8926c45c6c402c7beba Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac5ccfbf62aee829acbef0418762a79022681f60 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_bwd.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-311.pyc b/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce50bccc831ee99cd899900c900335a761193af4 Binary files /dev/null and b/fla/ops/generalized_delta_rule/dplr/__pycache__/wy_fast_fwd.cpython-311.pyc differ diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a17bcfb2bad98fbf5df1dab70f21a86a59f111 --- /dev/null +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py @@ -0,0 +1,464 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils.op import exp +from fla.utils import check_shared_mem, use_cuda_graph + +BK_LIST = [32, 64, 128] if check_shared_mem() else [16, 32] + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BV', 'BT'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=['T']) +def chunk_dplr_bwd_kernel_dAu( + v, + do, + v_new, + A_qb, + dA_qk, + dA_qb, + dv_new, + offsets, + indices, + scale: tl.constexpr, + T, + H: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = 
eos - bos + + b_dA_qk = tl.zeros([BT, BT], dtype=tl.float32) + b_dA_qb = tl.zeros([BT, BT], dtype=tl.float32) + + if HEAD_FIRST: + p_A_qb = tl.make_block_ptr(A_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_A_qb = tl.make_block_ptr(A_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + + b_A_qb = tl.load(p_A_qb, boundary_check=(0, 1)) + # causal mask + b_A_qb = tl.where(tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :], b_A_qb, 0.).to(b_A_qb.dtype) + + for i_v in range(tl.cdiv(V, BV)): + if HEAD_FIRST: + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_v_new = tl.make_block_ptr(v_new + i_bh * T*V, (V, T), (1, V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_dv_new = tl.make_block_ptr(dv_new + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + else: + p_do = tl.make_block_ptr(do + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_v_new = tl.make_block_ptr(v_new + (bos*H + i_h) * V, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) + p_dv_new = tl.make_block_ptr(dv_new + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_dA_qk += tl.dot(b_do, b_v) + b_dA_qb += tl.dot(b_do, b_v_new) + b_dv_new = tl.dot(tl.trans(b_A_qb), b_do) + # for recurrent + tl.store(p_dv_new, b_dv_new.to(p_dv_new.dtype.element_ty), boundary_check=(0, 1)) + + if HEAD_FIRST: + p_dA_qk = tl.make_block_ptr(dA_qk + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dA_qb = tl.make_block_ptr(dA_qb + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + else: + p_dA_qk = tl.make_block_ptr(dA_qk + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + p_dA_qb = tl.make_block_ptr(dA_qb + (bos * H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + b_dA_qk = tl.where(m_s, b_dA_qk * scale, 0.) + tl.store(p_dA_qk, b_dA_qk.to(p_dA_qk.dtype.element_ty), boundary_check=(0, 1)) + b_dA_qb = tl.where(m_s, b_dA_qb * scale, 0.) 
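+ # as with dA_qk above, only the causal (lower-triangular) part of dA_qb is kept;
+ # entries above the diagonal never contributed to the forward output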
+ tl.store(p_dA_qb, b_dA_qb.to(p_dA_qb.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def chunk_dplr_bwd_o_kernel( + v, + v_new, + h, + do, + dh, + dk, + db, + w, + dq, + dv, + dw, + gk, + dgk_last, + k, + b, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + v_new += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + h += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V + dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K * V + dk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + db += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + b += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dw += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + dq += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + w += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + # CHECK HEAD_FIRST is FALSE + dgk_last += (i_bh * NT + i_t) * K if HEAD_FIRST else (i_tg * H + i_h) * K + gk += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + + stride_qk = K if HEAD_FIRST else H*K + stride_vo = V if HEAD_FIRST else H*V + + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_db = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk_last = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_v_new = tl.load(p_v_new, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_dgk_last += tl.sum((b_h * b_dh).to(tl.float32), axis=0) + + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + # [BT, BV] @ [BV, BK] -> [BT, BK] + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + b_db += tl.dot(b_v_new, b_dh.to(b_v_new.dtype)) + p_dv = 
tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + m_k = (i_k*BK+tl.arange(0, BK)) < K + last_idx = min(i_t * BT + BT, T) - 1 + b_gk_last = tl.load(gk + last_idx * stride_qk + i_k*BK + tl.arange(0, BK), mask=m_k, other=float('-inf')) + b_dgk_last *= exp(b_gk_last) + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_b = tl.make_block_ptr(b, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_b = tl.load(p_b, boundary_check=(0, 1)) + b_dgk_last += tl.sum(b_k * b_dk, axis=0) + b_dgk_last += tl.sum(b_b * b_db, axis=0) + tl.store(dgk_last + tl.arange(0, BK) + i_k * BK, b_dgk_last, mask=m_k) + + p_dw = tl.make_block_ptr(dw, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({'BK': BK, 'BV': BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16, 32] + for num_stages in [2, 3, 4] + for BK in BK_LIST + for BV in BK_LIST + ], + key=['BT', 'BK', 'BV'], + use_cuda_graph=use_cuda_graph, +) +@triton.jit +def chunk_dplr_bwd_kernel_dv( + A_qk, + kg, + do, + dv, + dh, + offsets, + indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if USE_OFFSETS: + i_tg = i_t + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + + # offset calculation + A_qk += i_bh * T * BT if HEAD_FIRST else (bos * H + i_h) * BT + do += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + dv += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + kg += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + dh += (i_bh * NT + i_t) * K*V if HEAD_FIRST else (i_tg * H + i_h) * K*V + + stride_qk = K if HEAD_FIRST else H*K + stride_vo = V if HEAD_FIRST else H*V + stride_A = BT if HEAD_FIRST else H*BT + + for i_k in range(tl.cdiv(K, BK)): + p_dh = tl.make_block_ptr(dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + p_kg = tl.make_block_ptr(kg, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + b_kg = tl.load(p_kg, boundary_check=(0, 1)) + b_dv += tl.dot(b_kg, b_dh.to(b_kg.dtype)) + + p_Aqk = tl.make_block_ptr(A_qk, (BT, T), (1, stride_A), (0, i_t * 
BT), (BT, BT), (0, 1)) + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], tl.load(p_Aqk, boundary_check=(0, 1)), 0) + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv += tl.dot(b_A.to(b_do.dtype), b_do) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_dplr_bwd_dv( + A_qk: torch.Tensor, + kg: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, K, V = *kg.shape, do.shape[-1] + else: + B, T, H, K, V = *kg.shape, do.shape[-1] + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + dv = torch.empty_like(do) + + def grid(meta): return ( + triton.cdiv(V, meta['BV']), + NT, + B * H + ) + chunk_dplr_bwd_kernel_dv[grid]( + A_qk=A_qk, + kg=kg, + do=do, + dv=dv, + dh=dh, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + HEAD_FIRST=head_first + ) + return dv + + +def chunk_dplr_bwd_o( + k: torch.Tensor, + b: torch.Tensor, + v: torch.Tensor, + v_new: torch.Tensor, + gk: torch.Tensor, + do: torch.Tensor, + h: torch.Tensor, + dh: torch.Tensor, + dv: torch.Tensor, + w: torch.Tensor, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + scale: float = 1.0, + head_first: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if head_first: + B, H, T, K, V = *w.shape, v.shape[-1] + else: + B, T, H, K, V = *w.shape, v.shape[-1] + + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + BK = min(triton.next_power_of_2(K), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32) + BV = min(triton.next_power_of_2(V), 64) if check_shared_mem() else min(triton.next_power_of_2(K), 32) + NK = triton.cdiv(K, BK) + dq = torch.empty_like(k) + dk = torch.empty_like(k) + dw = torch.empty_like(w) + db = torch.empty_like(b) + grid = (NK, NT, B * H) + + dgk_last = torch.empty(B, H, NT, K, dtype=torch.float, device=w.device) if head_first \ + else torch.empty(B, NT, H, K, dtype=torch.float, device=w.device) + + chunk_dplr_bwd_o_kernel[grid]( + k=k, + b=b, + v=v, + v_new=v_new, + h=h, + do=do, + dh=dh, + dq=dq, + dk=dk, + db=db, + dgk_last=dgk_last, + w=w, + dv=dv, + dw=dw, + gk=gk, + offsets=offsets, + indices=indices, + T=T, + H=H, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + HEAD_FIRST=head_first, + ) + return dq, dk, dw, db, dgk_last + + +def chunk_dplr_bwd_dAu( + v: torch.Tensor, + v_new: torch.Tensor, + do: torch.Tensor, + A_qb: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> torch.Tensor: + if head_first: + B, H, T, V = v.shape + else: + B, T, H, V = v.shape + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + if check_shared_mem('ampere'): # A100 + BV = min(triton.next_power_of_2(V), 128) + elif check_shared_mem('ada'): # 4090 + BV = min(triton.next_power_of_2(V), 64) + else: + BV = min(triton.next_power_of_2(V), 32) + + grid = (NT, B 
* H) + dA_qk = torch.empty(B, H, T, BT, dtype=torch.float, device=v.device) if head_first \ + else torch.empty(B, T, H, BT, dtype=torch.float, device=v.device) + dA_qb = torch.empty(B, H, T, BT, dtype=torch.float, device=v.device) if head_first \ + else torch.empty(B, T, H, BT, dtype=torch.float, device=v.device) + dv_new = torch.empty_like(v_new) + chunk_dplr_bwd_kernel_dAu[grid]( + v=v, + do=do, + v_new=v_new, + A_qb=A_qb, + dA_qk=dA_qk, + dA_qb=dA_qb, + dv_new=dv_new, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + H=H, + V=V, + BT=BT, + BV=BV, + HEAD_FIRST=head_first + ) + return dv_new, dA_qk, dA_qb diff --git a/fla/ops/linear_attn/__init__.py b/fla/ops/linear_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a981054aaf9ab98b30ac08fa525bde73e68e7e4 --- /dev/null +++ b/fla/ops/linear_attn/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_linear_attn +from .fused_chunk import fused_chunk_linear_attn +from .fused_recurrent import fused_recurrent_linear_attn + +__all__ = [ + 'chunk_linear_attn', + 'fused_chunk_linear_attn', + 'fused_recurrent_linear_attn' +] diff --git a/fla/ops/linear_attn/__pycache__/__init__.cpython-311.pyc b/fla/ops/linear_attn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..702d8b4c2a992fc61e600cf5fac6075333462827 Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/linear_attn/__pycache__/chunk.cpython-311.pyc b/fla/ops/linear_attn/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e3d87cd840ff47d1a0c0dc3210480ae3d203e7a Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/linear_attn/__pycache__/fused_chunk.cpython-311.pyc b/fla/ops/linear_attn/__pycache__/fused_chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37032826ed274dafa7e3f8b397ddcf7ec3cb1681 Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/fused_chunk.cpython-311.pyc differ diff --git a/fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52e58cf83df119386f70e36b0febf164219bdc3f Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/linear_attn/__pycache__/utils.cpython-311.pyc b/fla/ops/linear_attn/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b90a478f5b410c49ce780215f6c62021b7af64 Binary files /dev/null and b/fla/ops/linear_attn/__pycache__/utils.cpython-311.pyc differ diff --git a/fla/ops/linear_attn/chunk.py b/fla/ops/linear_attn/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..8283e707923389e5c0f4e8294f7c491277f7243d --- /dev/null +++ b/fla/ops/linear_attn/chunk.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Yu Zhang, Songlin Yang + +from typing import Optional, Tuple + +import torch + +from fla.ops.linear_attn.utils import normalize_output +from fla.ops.simple_gla import chunk_simple_gla + + +@torch.compiler.disable +def chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + 
normalize: bool = True, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + scale (Optional[int]): + Scale factor for the linear attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. Default: `False`. + normalize (bool): + Whether to normalize the output. Default: `True`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` + """ + + if scale is None: + scale = k.shape[-1] ** -0.5 + + o, final_state = chunk_simple_gla( + q=q, + k=k, + v=v, + scale=scale, + g=None, + initial_state=initial_state, + output_final_state=output_final_state, + head_first=head_first + ) + if normalize: + o = normalize_output(q * scale, k, o) + return o, final_state diff --git a/fla/ops/linear_attn/fused_chunk.py b/fla/ops/linear_attn/fused_chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..bfcc1212a534aa3debb5bb2d1cdbbce5f95f06e4 --- /dev/null +++ b/fla/ops/linear_attn/fused_chunk.py @@ -0,0 +1,318 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl +from packaging import version + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +@triton.jit +def fused_chunk_linear_attn_fwd_kernel( + q, # query [B, H, T, K] + k, # key [B, H, T, V] + v, # value [B, H, T, V] + o, # output [B, H, T, V] + h0, + ht, + scale, + B, # batch size + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + o_i = tl.arange(0, BT) + + # [BT, BT] + m_s = o_i[:, None] >= o_i[None, :] + # [BK, BV] + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # make block pointers + p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (0, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, 0), (BK, BT), (0, 1)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (0, i_v * BV), (BT, BV), (1, 0)) + + if USE_INITIAL_STATE: + p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_q, b_k, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False) + if CHECK and i == 0: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + else: + b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_k, b_v, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + p_q = tl.advance(p_q, (BT, 0)) + p_k = tl.advance(p_k, (0, BT)) + p_v = tl.advance(p_v, (BT, 0)) + p_o = tl.advance(p_o, (BT, 0)) + + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def fused_chunk_linear_attn_bwd_kernel( + q, # query [B, H, T, K] + k, # key [B, H, T, V] + v, # value [B, H, T, V] + do, # gradient of output [B, H, T, V] + dq, # gradient of query [NV, B, H, T, K] + dk, # gradient of key [NV, B, H, T, K] + dv, # gradient of value [NK, B, H, T, V] + h0, # initial state of the chunk [B, H, K, V] + scale, # K ** -0.5 + B, # B + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BT: tl.constexpr, # BLOCK SIZE along the sequence dimension, a.k.a. 
chunk size + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, + CHECK: tl.constexpr +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + o_i = tl.arange(0, BT) + + m_s = o_i[:, None] >= o_i[None, :] + # [BV, BK] + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + p_h = tl.make_block_ptr(h0 + i_bh * K * V, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + b_h = tl.load(p_h, boundary_check=(0, 1)).to(tl.float32) + + for i in range(0, tl.cdiv(T, BT)): + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (V, T), (1, V), (i_v * BV, i * BT), (BV, BT), (0, 1)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dq = tl.make_block_ptr(dq + (i_bh + i_v*B*H) * T*K, (T, K), (K, 1), (i*BT, i_k*BK), (BT, BK), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [V, BT] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, V] + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_ds = tl.dot(b_do, b_v, allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, BK] + b_dq = tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False) + # [BV, BK] + if CHECK and i == 0: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + else: + b_dq += tl.dot(b_do, b_h.to(b_do.dtype), allow_tf32=False) + b_h = b_h + tl.dot(b_v, b_k, allow_tf32=False) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + # sync threads + b_h = None + tl.debug_barrier() + # [BK, BV] + b_dh = tl.zeros([BK, BV], dtype=tl.float32) + m_s = o_i[:, None] <= o_i[None, :] + for i in range(1, tl.cdiv(T, BT) + 1): + p_q = tl.make_block_ptr(q + i_bh * T*K, (K, T), (1, K), (i_k * BK, T - i * BT), (BK, BT), (0, 1)) + p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (T - i * BT, i_k * BK), (BT, BK), (1, 0)) + p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do + i_bh * T*V, (T, V), (V, 1), (T - i * BT, i_v * BV), (BT, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_bh+i_v*B*H) * T*K, (T, K), (K, 1), (T - i*BT, i_k*BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_bh+i_k*B*H) * T*V, (T, V), (V, 1), (T - i*BT, i_v*BV), (BT, BV), (1, 0)) + # [BK, BT] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # [BT, BT] + b_s = tl.dot(b_k, b_q, allow_tf32=False) + b_s = tl.where(m_s, b_s, 0).to(b_q.dtype) + # [BT, BT] + b_ds = tl.dot(b_v, tl.trans(b_do), allow_tf32=False) + b_ds = tl.where(m_s, b_ds, 0).to(b_q.dtype) + # [BT, BK] + b_dk = tl.dot(b_ds, tl.trans(b_q), allow_tf32=False) + # [BT, BV] + b_dv = tl.dot(b_s, b_do, allow_tf32=False) + if CHECK and i == 1: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + else: + b_dk += tl.dot(b_v, tl.trans(b_dh).to(b_v.dtype), allow_tf32=False) + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype), allow_tf32=False) + b_dh += tl.dot(b_q, b_do, allow_tf32=False) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + 
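+ # Note: each per-chunk dk/dv combines the masked intra-chunk term computed above
+ # with the inter-chunk contribution carried by b_dh, which at this point holds the
+ # accumulated q^T @ do of all later chunks; this second loop walks the sequence in
+ # reverse and only updates b_dh after the current chunk's gradients are formed.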
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +class FusedChunkLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, scale, initial_state, output_final_state): + B, H, T, K, V = *k.shape, v.shape[-1] + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 + num_stages = 1 + + o = q.new_empty(NK, B, H, T, V) + final_state = q.new_empty(B, H, K, V, dtype=torch.float) if output_final_state else None + # the bug still exists even for Triton 2.2 on H100 GPUs + # so we always enable initial checks + CHECK = True + if version.parse(triton.__version__) < version.parse('2.2.0'): + import warnings + warnings.warn( + "Triton<2.2.0 detected for running this kernel, " + "which is known to have some weird compiler issues (refer to https://github.com/openai/triton/issues/2852) " + "that lead to significant precision loss. " + "We've add some initial condition checks to resolve this, sadly at the sacrifice of the speed. " + "For optimal performance, it is recommended to install Triton>=2.2.0 (if possible)." + ) + CHECK = True + + grid = (NV, NK, B * H) + fused_chunk_linear_attn_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=output_final_state, + CHECK=CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + o = o.sum(0) if NK > 1 else o[0] + + ctx.save_for_backward(q, k, v, initial_state) + ctx.scale = scale + ctx.CHECK = CHECK + return o.to(q.dtype), final_state + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, H, T, K, V = *k.shape, v.shape[-1] + scale = ctx.scale + + BT = 64 + BK, BV = min(triton.next_power_of_2(K), 64), min(triton.next_power_of_2(V), 64) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 4 + num_stages = 1 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_chunk_linear_attn_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + scale, + B=B, H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + CHECK=ctx.CHECK, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None + + +def fused_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = True, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + scale (Optional[int]): + Scale factor for linear attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[B, H, K, V]`. Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[B, H, K, V]`. 
Default: `False`. + normalize (bool): + Whether to normalize the output. Default: `True`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + final_state (torch.Tensor): + Final state of shape `[B, H, K, V]` if `output_final_state=True` else `None` + """ + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, final_state = FusedChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla/ops/linear_attn/fused_recurrent.py b/fla/ops/linear_attn/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..b50b8c7bfb470b69be5ba3327de24ed07ffa974d --- /dev/null +++ b/fla/ops/linear_attn/fused_recurrent.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.linear_attn.utils import normalize_output +from fla.utils import input_guard + + +@triton.jit +def fused_recurrent_linear_attn_fwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + o, # output [B, H, L, V] + h0, + ht, # final hidden state [B, H, K, V] + + s_k_h, # stride size: L * K + s_v_h, # stride size: L * V + + scale, + B, # batch size + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + STORE_FINAL_STATE: tl.constexpr, # whether to store final state +): + # indices + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + + mask_bk = (i_k * BK + tl.arange(0, BK)) < K + mask_bv = (i_v * BV + tl.arange(0, BV)) < V + mask_kv = mask_bk[None, :] & mask_bv[:, None] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + + if USE_INITIAL_STATE: + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + + b_h += b_k[None, :] * b_v[:, None] + b_o = b_h * b_q[None, :] + b_o = tl.sum(b_o, axis=1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv) + + p_q += K + p_k += K + p_o += V + p_v += V + + if STORE_FINAL_STATE: + p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None]) + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv) + + +# Similar to Algorithm1 of https://arxiv.org/abs/2006.16236 +@triton.jit +def fused_recurrent_linear_attn_bwd_kernel( + q, # query [B, H, L, K] + k, # key [B, H, L, V] + v, # value [B, H, L, V] + + do, # gradient of output [B, H, L, V] + dq, # gradient of query [NV, B, H, L, K] + dk, # gradient 
of key [NV, B, H, L, K] + dv, # gradient of value [NK, B, H, L, V] + h0, # initial hidden state initialization [B, H, K, V] + + s_k_h, # stride size: L * K + s_v_h, # stride size: L * V + scale, # K ** -0.5 + + B, # B + H, # H + T, # T + K: tl.constexpr, # K + V: tl.constexpr, # V + BK: tl.constexpr, # BLOCK SIZE along the K dimension + BV: tl.constexpr, # BLOCK SIZE along the V dimension + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state +): + i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + + p_dq = dq + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + mask_bk = i_k * BK + tl.arange(0, BK) < K + mask_bv = i_v * BV + tl.arange(0, BV) < V + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + if USE_INITIAL_STATE: + mask_kv = mask_bk[:, None] & mask_bv[None, :] + p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :]) + b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32) + + for _ in range(0, T): + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + + b_h += b_k[:, None] * b_v[None, :] + _d_q = b_h * b_do[None, :] + d_q = tl.sum(_d_q, axis=1) * scale + tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk) + + p_k += K + p_do += V + p_v += V + p_dq += K + + # sync threads + tl.debug_barrier() + + p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_do = do + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + p_dk = dk + (i_bh + i_v * B * H) * s_k_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K + p_dv = dv + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V + d_h = tl.zeros([BK, BV], dtype=tl.float32) + + for _ in range(T): + b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32) + b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale + b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32) + d_h += b_q[:, None] * b_do[None, :] + d_k = tl.sum(d_h * b_v[None, :], axis=1) + d_v = tl.sum(d_h * b_k[:, None], axis=0) + + tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk) + tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv) + + p_do -= V + p_q -= K + p_k -= K + p_v -= V + p_dk -= K + p_dv -= V + + +class FusedRecurrentLinearAttentionFunction(torch.autograd.Function): + + @staticmethod + @input_guard + def forward(ctx, q, k, v, scale, initial_state=None, output_final_state=False): + B, H, T, K = q.shape + V = v.shape[-1] + + BK, BV = min(K, 32), min(V, 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 1 + num_stages = 1 + + o = q.new_empty(NK, B, H, T, V) + final_state = q.new_empty(B, H, K, V) if output_final_state else None + + grid = (NV, NK, B * H) + fused_recurrent_linear_attn_fwd_kernel[grid]( + q, k, v, o, initial_state, final_state, + q.stride(1), + v.stride(1), scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + STORE_FINAL_STATE=final_state is not None, + num_warps=num_warps, + 
num_stages=num_stages + ) + + o = o.sum(0) + ctx.save_for_backward(q, k, v, initial_state) + ctx.scale = scale + return o, final_state + + @staticmethod + @input_guard + def backward(ctx, do, dht=None): + q, k, v, initial_state = ctx.saved_tensors + B, H, T, K = q.shape + V = v.shape[-1] + scale = ctx.scale + + BK, BV = min(K, 32), min(V, 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + num_warps = 1 + num_stages = 1 + + dq = q.new_empty(NV, B, H, T, K) + dk = q.new_empty(NV, B, H, T, K) + dv = q.new_empty(NK, B, H, T, V) + grid = (NV, NK, B * H) + + fused_recurrent_linear_attn_bwd_kernel[grid]( + q, k, v, do, dq, dk, dv, initial_state, + q.stride(1), + v.stride(1), + scale, + B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV, + USE_INITIAL_STATE=initial_state is not None, + num_warps=num_warps, + num_stages=num_stages + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + return dq, dk, dv, None, None, None + + +def fused_recurrent_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + normalize: bool = False, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + if not head_first: + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + o, final_state = FusedRecurrentLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state) + if normalize: + o = normalize_output(q * scale, k, o) + if not head_first: + o = o.transpose(1, 2) + return o, final_state diff --git a/fla/ops/linear_attn/naive.py b/fla/ops/linear_attn/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ecf2718fcac8eef80f445ed02b95f36329f3c4 --- /dev/null +++ b/fla/ops/linear_attn/naive.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- + +from typing import Optional, Tuple + +import torch +from einops import rearrange + +from fla.ops.linear_attn.utils import normalize_output + + +def naive_chunk_linear_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: Optional[float] = None, + normalize: bool = False +) -> Tuple[torch.Tensor, torch.Tensor]: + if scale is None: + scale = q.shape[-1] ** -0.5 + chunk_size = 64 + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + kv = k.transpose(-1, -2) @ v + kv = kv.cumsum(2) + kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) + inter = q @ kv + intra = (( + q @ k.transpose(-1, -2)).masked_fill_( + torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), + 0 + )) @ v + o = inter + intra + if normalize: + o = normalize_output(q * scale, k, o) + return rearrange(o, 'b h n c d -> b h (n c) d') diff --git a/fla/ops/linear_attn/utils.py b/fla/ops/linear_attn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b444376833f5d512af6fc2db387db75a43a92e5d --- /dev/null +++ b/fla/ops/linear_attn/utils.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- + +import torch + + +@torch.jit.script +def normalize_output(q, k, o): + k = k.cumsum(-2) + z = (q * k).sum(-1, keepdim=True) + return o / (z + 1e-10) diff --git a/fla/ops/nsa/__init__.py b/fla/ops/nsa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..941a1be41e1650961af0d28e64837421826ffab2 --- /dev/null +++ b/fla/ops/nsa/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- 
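+# Native Sparse Attention (NSA) ops: `naive_nsa` is the pure-PyTorch reference
+# implementation, while `parallel_nsa` is the Triton-based implementation intended
+# for actual training and inference.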
+ +from .naive import naive_nsa +from .parallel import parallel_nsa + +__all__ = [ + 'naive_nsa', + 'parallel_nsa' +] diff --git a/fla/ops/nsa/__pycache__/__init__.cpython-311.pyc b/fla/ops/nsa/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5c0336715db3f7b15dfb771597ab5c92f3f50cb Binary files /dev/null and b/fla/ops/nsa/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/nsa/__pycache__/naive.cpython-311.pyc b/fla/ops/nsa/__pycache__/naive.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1afa3e040856ef8be1aebcb061be24d97d24f2d Binary files /dev/null and b/fla/ops/nsa/__pycache__/naive.cpython-311.pyc differ diff --git a/fla/ops/nsa/__pycache__/parallel.cpython-311.pyc b/fla/ops/nsa/__pycache__/parallel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1687114644d1acc3bd506403e9a2b1e44482b765 Binary files /dev/null and b/fla/ops/nsa/__pycache__/parallel.cpython-311.pyc differ diff --git a/fla/ops/nsa/__pycache__/utils.cpython-311.pyc b/fla/ops/nsa/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fdc69ad9287abf61775b6f5f5eb27f5d1a3d4cd Binary files /dev/null and b/fla/ops/nsa/__pycache__/utils.cpython-311.pyc differ diff --git a/fla/ops/nsa/naive.py b/fla/ops/nsa/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..79433b80ef024874816f3b8e4fff47fe93a2578d --- /dev/null +++ b/fla/ops/nsa/naive.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional + +import torch +from einops import rearrange, repeat + + +def naive_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + indices: torch.LongTensor, + block_size: int = 64, + scale: Optional[float] = None, + head_first: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, HQ, T, K]` if `head_first=True` else `[B, T, HQ, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=True` else `[B, T, H, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + block_size (int): + Selected block size. Default: 64. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, HQ, T, V]` if `head_first=True` else `[B, T, HQ, V]`. 
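+ 
+ Example:
+ A minimal runnable sketch; the shapes and block indices below are arbitrary,
+ chosen only to satisfy the GQA ratio and block-size constraints stated above:
+ >>> B, T, HQ, H, K, V, S, BS = 2, 128, 32, 2, 64, 64, 2, 64
+ >>> q = torch.randn(B, T, HQ, K)
+ >>> k = torch.randn(B, T, H, K)
+ >>> v = torch.randn(B, T, H, V)
+ >>> # every token selects all `S` blocks; the causal mask inside discards invalid positions
+ >>> indices = torch.arange(S).expand(B, T, H, S).contiguous()
+ >>> o = naive_nsa(q, k, v, indices, block_size=BS)  # [B, T, HQ, V]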
+ """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if head_first: + q, k, v, indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v, indices)) + + dtype = q.dtype + G = q.shape[2] // k.shape[2] + BS = block_size + k, v, indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, indices)) + q, k, v = map(lambda x: x.float(), (q, k, v)) + + o = torch.zeros_like(v) + varlen = True + if cu_seqlens is None: + varlen = False + B, T = q.shape[:2] + cu_seqlens = torch.cat([indices.new_tensor(range(0, B*T, T)), indices.new_tensor([B*T])]) + + for i in range(len(cu_seqlens) - 1): + if not varlen: + q_b, k_b, v_b, i_b = q[i], k[i], v[i], indices[i] + else: + T = cu_seqlens[i+1] - cu_seqlens[i] + q_b, k_b, v_b, i_b = map(lambda x: x[0][cu_seqlens[i]:cu_seqlens[i+1]], (q, k, v, indices)) + + i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS)) + # [T, S*BS, HQ] + i_b = i_b.view(T, indices.shape[2], -1).transpose(1, 2) + for i_q in range(T): + # [HQ, D] + q_i = q_b[i_q] * scale + # [S*BS, HQ] + i_i = i_b[i_q] + # [S*BS, HQ, -1] + k_i, v_i = map(lambda x: x.gather(0, i_i.clamp(0, T-1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b)) + # [S*BS, HQ] + attn = torch.einsum('h d, n h d -> n h', q_i, k_i).masked_fill(i_i > i_q, float('-inf')).softmax(0) + if not varlen: + o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) + else: + o[0][cu_seqlens[i]+i_q] = torch.einsum('n h, n h v -> h v', attn, v_i) + + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o.to(dtype) diff --git a/fla/ops/nsa/parallel.py b/fla/ops/nsa/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..7e89d964c7357ceeabaaeb9500849ce6cbdecfad --- /dev/null +++ b/fla/ops/nsa/parallel.py @@ -0,0 +1,1435 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import warnings +from typing import Optional, Union + +import torch +import triton +import triton.language as tl +from einops import rearrange + +from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_lens, prepare_token_indices +from fla.ops.nsa.utils import _bitonic_merge +from fla.ops.utils import mean_pooling +from fla.ops.utils.op import exp, log +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func +except ImportError: + warnings.warn( + "Flash Attention is not installed. 
Please install it via `pip install flash-attn --no-build-isolation`", + category=ImportWarning + ) + flash_attn_func = None + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit +def parallel_nsa_compression_fwd_kernel( + q, + k, + v, + o, + lse, + scale, + offsets, + token_indices, + chunk_offsets, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + boc = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + boc = i_b * tl.cdiv(T, BS) + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + # the number of compression representations in total + TC = tl.cdiv(T, BS) + # the number of compression representations required to iterate over + # incomplete compression blocks are not included + NC = (i_t + 1) // BS + + p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + # [G, BV] + b_o = tl.zeros([G, BV], dtype=tl.float32) + # max scores for the current block + b_m = tl.full([G], float('-inf'), dtype=tl.float32) + # lse = log(acc) + m + b_acc = tl.zeros([G], dtype=tl.float32) + + for i_c in range(0, NC, BC): + o_c = i_c + tl.arange(0, BC) + + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) + p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [G, BC] + b_s = tl.dot(b_q, b_k) + b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf')) + + # [G] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [G, BC] + b_p = exp(b_s - b_m[:, None]) + # [G] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + + # [G, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + if NC == 0: + b_lse = tl.zeros([G], dtype=tl.float32) + else: + b_o = b_o / b_acc[:, None] + b_lse = b_m + log(b_acc) + + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + if i_v == 0: + tl.store(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G), b_lse.to(lse.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_compression_bwd_kernel_dq( + q, + k, + v, + lse, + delta, + do, + dq, + scale, + offsets, + token_indices, + chunk_offsets, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BC: 
tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + boc = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + boc = i_b * tl.cdiv(T, BS) + + q += (bos + i_t) * HQ*K + do += (bos + i_t) * HQ*V + lse += (bos + i_t) * HQ + delta += (bos + i_t) * HQ + dq += (i_v * B * T + bos + i_t) * HQ*K + + p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + i_h * G + tl.arange(0, G) + p_delta = delta + i_h * G + tl.arange(0, G) + + # the number of compression representations in total + TC = tl.cdiv(T, BS) + # the number of compression representations required to iterate over + # incomplete compression blocks are not included + NC = (i_t + 1) // BS + + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse = tl.load(p_lse) + b_delta = tl.load(p_delta) + + # [G, BK] + b_dq = tl.zeros([G, BK], dtype=tl.float32) + for i_c in range(0, NC, BC): + o_c = i_c + tl.arange(0, BC) + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) + p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BC] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [G, BC] + b_s = tl.dot(b_q, b_k) + b_p = exp(b_s - b_lse[:, None]) + b_p = tl.where((o_c < NC)[None, :], b_p, 0) + + # [G, BV] @ [BV, BC] -> [G, BC] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [G, BC] @ [BC, BK] -> [G, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + b_dq *= scale + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_compression_bwd_kernel_dkv( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + offsets, + chunk_indices, + chunk_offsets, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_c = tl.load(chunk_indices + i_c * 2).to(tl.int32), tl.load(chunk_indices + i_c * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + boc = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + boc = i_b * tl.cdiv(T, BS) + + # the number of compression representations in total + TC = tl.cdiv(T, BS) + + p_k = tl.make_block_ptr(k 
+ (boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0)) + + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BC, BK], dtype=tl.float32) + # [BC, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BC, BV], dtype=tl.float32) + + for i in range(i_c * BC * BS, T): + o_c = i_c * BC + tl.arange(0, BC) + + p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse = tl.load(p_lse) + b_delta = tl.load(p_delta) + # [BC, G] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_p = exp(b_s - b_lse[None, :]) + b_p = tl.where((i >= max(0, (o_c + 1) * BS - 1))[:, None], b_p, 0) + # [BC, G] @ [G, BV] -> [BC, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BC, BV] @ [BV, G] -> [BC, G] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BC, G] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BC, G] @ [G, BK] -> [BC, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK'], +) +@triton.jit +def parallel_nsa_kernel_topk( + q, + k, + lse, + scale, + block_indices, + offsets, + token_indices, + chunk_offsets, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + S: tl.constexpr, + BC: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + USE_OFFSETS: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + boc = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + boc = i_b * tl.cdiv(T, BS) + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + # the number of compression representations in total + TC = tl.cdiv(T, BS) + # the number of compression representations required to iterate over + # incomplete compression blocks are not included + NC = (i_t + 1) // BS + ################################ + # 1. 
lse computation + ################################ + if lse is not None: + b_lse = tl.load(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)) + else: + # max scores for the current block + b_m = tl.full([G], float('-inf'), dtype=tl.float32) + # lse = log(acc) + m + b_acc = tl.zeros([G], dtype=tl.float32) + for i_c in range(0, NC, BC): + o_c = i_c + tl.arange(0, BC) + + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + + # [G, BC] + b_s = tl.dot(b_q, b_k) + b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf')) + + # [G] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [G, BC] + b_p = exp(b_s - b_m[:, None]) + # [G] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + + b_mp = b_m + if NC == 0: + b_lse = tl.zeros([G], dtype=tl.float32) + else: + b_lse = b_m + log(b_acc) + + ################################ + # 2. topk selection + ################################ + # [BC] + b_i = tl.full([BC], -1, dtype=tl.float32) + o_i = tl.zeros([BC], dtype=tl.int32) + m_i = tl.arange(0, BC) < BC//2 + for i_c in range(0, i_t // BS + 1, BC): + o_c = i_c + tl.arange(0, BC) + + p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) + # [BK, BC] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [G, BC] + b_s = tl.dot(b_q, b_k) + b_s = tl.where((i_t // BS > o_c)[None, :], b_s, float('-inf')) + # [G, BC] + b_p = tl.where((i_t // BS == o_c)[None, :], float(1.0), exp(b_s - b_lse[:, None])) + # the importance scores of the current block + # [BC] + b_i, b_ip = tl.sum(b_p, 0), b_i + o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i + + n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0]) + for i in tl.static_range(1, n_dims): + b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims) + + if i_c != 0: + b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, False, n_dims) + b_i_new = b_ip * m_i + b_i * (1 - m_i) + o_i_new = o_ip * m_i + o_i * (1 - m_i) + b_i, o_i = _bitonic_merge(b_i_new, o_i_new.to(tl.int32), n_dims, True, n_dims) + else: + b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims) + + m_top = tl.arange(0, BC//S) == 0 + b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [BC//S, S]), 0) + + p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,)) + tl.store(p_b, b_top.to(p_b.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit +def parallel_nsa_fwd_kernel( + q, + k, + v, + o, + lse, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + k += (bos * H + 
i_h) * K + v += (bos * H + i_h) * V + block_indices += (bos + i_t) * H*S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # the Q block is kept in the shared memory throughout the whole kernel + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_o = tl.zeros([G, BV], dtype=tl.float32) + + b_m = tl.full([G], float('-inf'), dtype=tl.float32) + b_acc = tl.zeros([G], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [G, BS] + b_s = tl.dot(b_q, b_k) + b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf')) + + # [G] + b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m + b_r = exp(b_mp - b_m) + # [G, BS] + b_p = exp(b_s - b_m[:, None]) + # [G] + b_acc = b_acc * b_r + tl.sum(b_p, 1) + # [G, BV] + b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) + + b_mp = b_m + b_o = b_o / b_acc[:, None] + b_m += log(b_acc) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_lse, b_m.to(p_lse.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) +}) +@triton.jit +def parallel_nsa_kernel_mask( + block_indices, + block_counts, + block_mask, + T: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + NS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr +): + i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_h, i_s = i_hs // S, i_hs % S + + b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s) + if USE_BLOCK_COUNTS: + b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h) + else: + b_m = b_i * BS <= i_t + + if b_i < NS and b_i >= 0: + tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) + + +@triton.jit +def parallel_nsa_bwd_kernel_preprocess( + o, + do, + delta, + B: tl.constexpr, + V: tl.constexpr +): + i_n = tl.program_id(0) + o_d = tl.arange(0, B) + m_d = o_d < V + + b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) + b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) + b_delta = tl.sum(b_o * b_do) + + tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_bwd_kernel_dq( + q, + k, + v, + lse, + delta, + do, + dq, + scale, + block_indices, + block_counts, + offsets, + token_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + S: tl.constexpr, + BS: tl.constexpr, + BK: 
tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_BLOCK_COUNTS: tl.constexpr +): + i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + q += (bos + i_t) * HQ*K + do += (bos + i_t) * HQ*V + lse += (bos + i_t) * HQ + delta += (bos + i_t) * HQ + dq += (i_v * B * T + bos + i_t) * HQ*K + block_indices += (bos + i_t) * H*S + i_h * S + + if USE_BLOCK_COUNTS: + NS = tl.load(block_counts + (bos + i_t) * H + i_h) + else: + NS = S + + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + + p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + i_h * G + tl.arange(0, G) + p_delta = delta + i_h * G + tl.arange(0, G) + + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse = tl.load(p_lse) + b_delta = tl.load(p_delta) + + # [G, BK] + b_dq = tl.zeros([G, BK], dtype=tl.float32) + for i in range(NS): + i_s = tl.load(block_indices + i).to(tl.int32) * BS + if i_s <= i_t and i_s >= 0: + p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # [G, BS] + b_s = tl.dot(b_q, b_k) + b_p = exp(b_s - b_lse[:, None]) + b_p = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p, 0) + + # [G, BV] @ [BV, BS] -> [G, BS] + b_dp = tl.dot(b_do, b_v) + b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) + # [G, BS] @ [BS, BK] -> [G, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) + b_dq *= scale + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics({ + 'USE_OFFSETS': lambda args: args['offsets'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps) + for num_warps in [1, 2, 4] + ], + key=['BS', 'BK', 'BV'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_nsa_bwd_kernel_dkv( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + block_mask, + offsets, + chunk_indices, + scale, + T, + B: tl.constexpr, + H: tl.constexpr, + HQ: tl.constexpr, + G: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + M: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_OFFSETS: tl.constexpr +): + i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if USE_OFFSETS: + i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + bos * H + i_h) * K, (T, K), 
(H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) + + # [BS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BS, BK], dtype=tl.float32) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BS, BV], dtype=tl.float32) + + for i in range(i_s * BS, T): + b_m = tl.load(block_mask + (bos + i) * H*M + i_h * M + i_s) + if b_m: + p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) + # [G, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + + p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) + p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G) + p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G) + # [G, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [G] + b_lse = tl.load(p_lse) + b_delta = tl.load(p_delta) + # [BS, G] + b_s = tl.dot(b_k, tl.trans(b_q)) + b_p = exp(b_s - b_lse[None, :]) + b_p = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p, 0) + # [BS, G] @ [G, BV] -> [BS, BV] + b_dv += tl.dot(b_p.to(b_do.dtype), b_do) + # [BS, BV] @ [BV, G] -> [BS, G] + b_dp = tl.dot(b_v, tl.trans(b_do)) + # [BS, G] + b_ds = b_p * (b_dp - b_delta[None, :]) + # [BS, G] @ [G, BK] -> [BS, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +def parallel_nsa_compression_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, HQ, K, V = *q.shape, v.shape[-1] + H = k.shape[2] + G = HQ // H + BC = BS = block_size + if check_shared_mem('hopper', q.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None + + grid = (T, NV, B * H) + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + parallel_nsa_compression_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + offsets=offsets, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_nsa_compression_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + block_size: int = 64, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, HQ, K, V = *q.shape, v.shape[-1] + H = k.shape[2] + G = HQ // H + BC = BS = block_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + if offsets is not None: + lens = prepare_lens(offsets) + chunk_indices = torch.cat([torch.arange(n) for n in triton.cdiv(triton.cdiv(lens, BS), BC).tolist()]) + chunk_indices = torch.stack([chunk_indices.eq(0).cumsum(0) - 1, chunk_indices], 1).to(offsets) + 
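+ # chunk_indices stores (sequence id, local chunk id) pairs, one per BC-sized
+ # group of compressed chunks, so each program of the dkv kernel below can map
+ # its program id back to a variable-length sequence and its chunk range.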
chunk_offsets = prepare_chunk_offsets(offsets, BS) + NC = len(chunk_indices) + else: + chunk_indices, chunk_offsets = None, None + NC = triton.cdiv(triton.cdiv(T, BS), BC) + + delta = parallel_nsa_bwd_preprocess(o, do) + + dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + grid = (T, NV, B * H) + parallel_nsa_compression_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + scale=scale, + offsets=offsets, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV + ) + dq = dq.sum(0) + + dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + grid = (NV, NC, B * H) + parallel_nsa_compression_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + offsets=offsets, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + BC=BC, + BS=BS, + BK=BK, + BV=BV + ) + dk = dk.sum(0) + return dq, dk, dv + + +class ParallelNSACompressionFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + block_size, + scale, + offsets + ): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_compression_fwd( + q=q, + k=k, + v=v, + block_size=block_size, + scale=scale, + offsets=offsets, + token_indices=token_indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype), lse + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do, *args): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_nsa_compression_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + block_size=ctx.block_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None + + +def parallel_nsa_topk( + q: torch.Tensor, + k: torch.Tensor, + lse: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, +) -> torch.LongTensor: + B, T, HQ, K = q.shape + H = k.shape[2] + G = HQ // H + S = block_counts if isinstance(block_counts, int) else block_counts.max().item() + S = triton.next_power_of_2(S) + # here we set BC = BS, but beware that they are actually decoupled + BC = BS = block_size + BK = triton.next_power_of_2(K) + + block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device) + token_indices = prepare_token_indices(offsets) if offsets is not None else None + chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None + grid = (T, B * H) + parallel_nsa_kernel_topk[grid]( + q=q, + k=k, + lse=lse, + scale=scale, + block_indices=block_indices, + offsets=offsets, + token_indices=token_indices, + chunk_offsets=chunk_offsets, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + S=S, + BC=BC, + BS=BS, + 
BK=BK + ) + return block_indices + + +def parallel_nsa_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + block_size: int, + scale: float, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + if check_shared_mem('hopper', q.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + else: + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert NK == 1, "The key dimension can not be larger than 256" + + grid = (T, NV, B * H) + o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) + lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) + + parallel_nsa_fwd_kernel[grid]( + q=q, + k=k, + v=v, + o=o, + lse=lse, + scale=scale, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + T=T, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + BK=BK, + BV=BV, + ) + return o, lse + + +def parallel_nsa_block_mask( + block_indices: torch.LongTensor, + block_counts: Union[torch.LongTensor, int], + offsets: torch.LongTensor, + block_size: int, +): + B, T, H, S = block_indices.shape + BS = block_size + if offsets is not None: + NS = triton.cdiv(prepare_lens(offsets).max().item(), BS) + else: + NS = triton.cdiv(T, BS) + block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) + + parallel_nsa_kernel_mask[(T, B, H*S)]( + block_indices=block_indices, + block_counts=block_counts, + block_mask=block_mask, + T=T, + H=H, + S=S, + BS=BS, + NS=NS + ) + return block_mask + + +def parallel_nsa_bwd_preprocess( + o: torch.Tensor, + do: torch.Tensor +): + V = o.shape[-1] + delta = torch.empty_like(o[..., 0], dtype=torch.float32) + parallel_nsa_bwd_kernel_preprocess[(delta.numel(),)]( + o=o, + do=do, + delta=delta, + B=triton.next_power_of_2(V), + V=V, + ) + return delta + + +def parallel_nsa_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + lse: torch.Tensor, + do: torch.Tensor, + block_indices: torch.Tensor, + block_counts: Union[torch.LongTensor, int], + block_size: int = 64, + scale: float = None, + offsets: Optional[torch.LongTensor] = None, + token_indices: Optional[torch.LongTensor] = None, +): + B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] + HQ = q.shape[2] + G = HQ // H + BS = block_size + BK = triton.next_power_of_2(K) + BV = min(128, triton.next_power_of_2(v.shape[-1])) + NV = triton.cdiv(V, BV) + + delta = parallel_nsa_bwd_preprocess(o, do) + + dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + grid = (T, NV, B * H) + parallel_nsa_bwd_kernel_dq[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dq=dq, + block_indices=block_indices, + block_counts=block_counts, + offsets=offsets, + token_indices=token_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + S=S, + BS=BS, + BK=BK, + BV=BV + ) + dq = dq.sum(0) + + if offsets is not None: + chunk_indices = prepare_chunk_indices(offsets, BS) + NS = len(chunk_indices) + else: + chunk_indices = None + NS = triton.cdiv(T, BS) + + # [B, T, H, M] + block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size) + dk = 
torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) + + grid = (NV, NS, B * H) + parallel_nsa_bwd_kernel_dkv[grid]( + q=q, + k=k, + v=v, + lse=lse, + delta=delta, + do=do, + dk=dk, + dv=dv, + block_mask=block_mask, + offsets=offsets, + chunk_indices=chunk_indices, + scale=scale, + T=T, + B=B, + H=H, + HQ=HQ, + G=G, + K=K, + V=V, + M=block_mask.shape[-1], + BS=BS, + BK=BK, + BV=BV + ) + dk = dk.sum(0) + return dq, dk, dv + + +@torch.compile +class ParallelNSAFunction(torch.autograd.Function): + + @staticmethod + @contiguous + @autocast_custom_fwd + def forward(ctx, q, k, v, block_indices, block_counts, block_size, scale, offsets): + ctx.dtype = q.dtype + + # 2-d sequence indices denoting the offsets of tokens in each sequence + # for example, if the passed `offsets` is [0, 2, 6], + # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + token_indices = prepare_token_indices(offsets) if offsets is not None else None + + o, lse = parallel_nsa_fwd( + q=q, + k=k, + v=v, + block_indices=block_indices, + block_counts=block_counts, + block_size=block_size, + scale=scale, + offsets=offsets, + token_indices=token_indices + ) + ctx.save_for_backward(q, k, v, o, lse) + ctx.block_indices = block_indices + ctx.block_counts = block_counts + ctx.offsets = offsets + ctx.token_indices = token_indices + ctx.block_size = block_size + ctx.scale = scale + return o.to(q.dtype) + + @staticmethod + @contiguous + @autocast_custom_bwd + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + dq, dk, dv = parallel_nsa_bwd( + q=q, + k=k, + v=v, + o=o, + lse=lse, + do=do, + block_indices=ctx.block_indices, + block_counts=ctx.block_counts, + block_size=ctx.block_size, + scale=ctx.scale, + offsets=ctx.offsets, + token_indices=ctx.token_indices + ) + return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None + + +def parallel_nsa_compression( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_size: int = 64, + scale: float = None, + offsets: Optional[torch.LongTensor] = None +): + return ParallelNSACompressionFunction.apply( + q, + k, + v, + block_size, + scale, + offsets + ) + + +def parallel_nsa( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g_cmp: torch.Tensor, + g_slc: torch.Tensor, + g_swa: torch.Tensor, + block_indices: Optional[torch.LongTensor] = None, + block_counts: Union[torch.LongTensor, int] = 16, + block_size: int = 64, + window_size: int = 0, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False +) -> torch.Tensor: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g_cmp (torch.Tensor): + Gate score for compressed attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_slc (torch.Tensor): + Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. + g_swa (torch.Tensor): + Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. 
+ block_indices (torch.LongTensor): + Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. + `S` is the number of selected blocks for each query token, which is set to 16 in the paper. + If `g_cmp` is provided, the passed `block_indices` will be ignored. + block_counts (Optional[Union[torch.LongTensor, int]]): + Number of selected blocks for each query. + If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`, + each query can select the same number of blocks. + If not provided, it will default to 16. + block_size (int): + Selected block size. Default: 64. + window_size (int): + Sliding window size. Default: 0. + scale (Optional[int]): + Scale factor for attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. + """ + assert block_counts is not None, "block counts must be provided for selection" + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + if head_first: + q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) + g_cmp, g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h') if x is not None else None, (g_cmp, g_slc, g_swa)) + if not isinstance(block_counts, int): + block_counts = rearrange(block_counts, 'b h t -> b t h') + assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" + + k_cmp, v_cmp = mean_pooling(k, block_size, cu_seqlens), mean_pooling(v, block_size, cu_seqlens) + o_cmp, lse_cmp = None, None + if g_cmp is not None: + o_cmp, lse_cmp = parallel_nsa_compression( + q=q, + k=k_cmp, + v=v_cmp, + block_size=block_size, + scale=scale, + offsets=cu_seqlens + ) + if block_indices is not None: + warnings.warn("`block_indices` will be ignored when `g_cmp` is provided") + block_indices = parallel_nsa_topk( + q=q, + k=k_cmp, + lse=lse_cmp, + block_counts=block_counts, + block_size=block_size, + scale=scale, + offsets=cu_seqlens + ) + o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, scale, cu_seqlens) + o = o_slc * g_slc.unsqueeze(-1) + if o_cmp is not None: + o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1)) + if window_size > 0: + if cu_seqlens is not None: + max_seqlen = q.shape[1] + o_swa = flash_attn_varlen_func( + q.squeeze(0), k.squeeze(0), v.squeeze(0), + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + causal=True, + window_size=(window_size-1, 0) + ).unsqueeze(0) + else: + o_swa = flash_attn_func( + q, k, v, + causal=True, + window_size=(window_size-1, 0) + ) + o = torch.addcmul(o, o_swa, g_swa.unsqueeze(-1)) + if head_first: + o = rearrange(o, 'b t h d -> b h t d') + return o diff --git a/fla/ops/nsa/utils.py b/fla/ops/nsa/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73e54138b750a280c4f8edd04ca36ffb3f58705f --- /dev/null +++ b/fla/ops/nsa/utils.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Implements argsort based on bitonic sort. 
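As a quick behavioural reference for the `argsort` defined in this file: it sorts along the minor (last) dimension, whose length must be a power of two (`n_dims = log2(...)` below), and carries an `ids` tensor through the same permutation. A minimal plain-PyTorch sketch of that contract, offered as an editor's illustration only (tie order may differ from the bitonic network):

import torch

def argsort_reference(x: torch.Tensor, ids: torch.Tensor, descending: bool = False):
    # Sort the last dimension and reorder `ids` with the same permutation,
    # which is the contract of the Triton bitonic `argsort` below.
    values, order = x.sort(dim=-1, descending=descending)
    return values, ids.gather(-1, order)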
+# [What is bitonic sort?](https://en.wikipedia.org/wiki/Bitonic_sorter) + +# Code adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396 + + +import triton +import triton.language as tl + +from fla.ops.utils.op import log2 + + +@triton.jit +def _compare_and_swap( + x, + ids, + flip, + i: tl.constexpr, + n_dims: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)] + y = tl.reshape(x, shape) + # slice left/right with 'stride' 2**(n_dims - i - 1) + mask = tl.arange(0, 2)[None, :, None] + left = tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(y.dtype) + right = tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(y.dtype) + left = tl.reshape(left, x.shape) + right = tl.reshape(right, x.shape) + # idx + y_idx = tl.reshape(ids, shape) + left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape) + right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape) + left_idx = tl.reshape(left_idx, x.shape).to(y_idx.dtype) + right_idx = tl.reshape(right_idx, x.shape).to(y_idx.dtype) + # actual compare-and-swap + idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + ileft = left.to(idtype, bitcast=True) + iright = right.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + cond = (left > right) != flip + ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix)) + new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids)) + return ret.to(x.dtype, bitcast=True), new_ids + + +@triton.jit +def _bitonic_merge( + x, + ids, + stage: tl.constexpr, + order: tl.constexpr, + n_dims: tl.constexpr, +): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + # flip denotes whether to re-arrange sub-sequences of elements in ascending or + # descending order. + # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage + # if flip = 00110011... 
then all the elements will be re-arranged alternatingly (with + # a stride of 2) at this stage + if order == 2: + shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage] + flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = order + # perform `stage` rounds of `compare-and-swap` + for i in tl.static_range(stage): + x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims) + return x, ids + + +@triton.jit +def argsort( + x, + ids, + dim: tl.constexpr = None, + descending: tl.constexpr = tl.core.CONSTEXPR_0, +): + # handle default dimension or check that it is the most minor dim + _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim + tl.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported") + # iteratively run bitonic merge-sort steps + n_dims: tl.constexpr = log2(x.shape[_dim]) + + for i in tl.static_range(1, n_dims + 1): + x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims) + return x, ids diff --git a/fla/ops/rwkv4/fused_recurrent.py b/fla/ops/rwkv4/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..63a5c6577dd3ef288aa59c494e74b8d29d8580ad --- /dev/null +++ b/fla/ops/rwkv4/fused_recurrent.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Any, cast + +import torch +import triton +import triton.language as tl +from torch import Tensor +from torch.autograd.function import Function, FunctionCtx, once_differentiable + +from fla.ops.utils.op import exp + + +def get_block_size_c(chans: int) -> int: + if chans < 32: + return 32 + if chans < 64: + return 64 + return 128 + + +@triton.jit +def fused_recurrent_rwkv4_forward_kernel( + # W + w_ptr, + w_s_c, + # U + u_ptr, + u_s_c, + # K + k_ptr, + k_s_b, + k_s_t, + k_s_c, + # V + v_ptr, + v_s_b, + v_s_t, + v_s_c, + # State + state_ptr, + state_s_b, + state_s_abe, + state_s_c, + # WKV + wkv_ptr, + wkv_s_b, + wkv_s_t, + wkv_s_c, + # Output state + state_out_ptr, + state_out_s_b, + state_out_s_abe, + state_out_s_t, + state_out_s_c, + # Params + chans, + tsz, + BLOCK_SIZE_C: tl.constexpr, +): + # Parallelize over the batch dimension. + b_idx = tl.program_id(0) + c_idx = tl.program_id(1) + + cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C) + cmask = cs < chans + + # Pointers to the batch (and possibly channel) for the input tensors. + k_ptr = k_ptr + b_idx * k_s_b + v_ptr = v_ptr + b_idx * v_s_b + alpha_ptr = state_ptr + b_idx * state_s_b + beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe + eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe + + # Pointers to the batch (and possibly channel) for the output tensors. + wkv_ptr = wkv_ptr + b_idx * wkv_s_b + alpha_out_ptr = state_out_ptr + b_idx * state_out_s_b + beta_out_ptr = state_out_ptr + b_idx * state_out_s_b + state_out_s_abe + eps_out_ptr = state_out_ptr + b_idx * state_out_s_b + 2 * state_out_s_abe + + # Loads parameters. 
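For readers new to RWKV4, a minimal PyTorch sketch of the stabilised WKV update that the per-timestep loop below performs for one channel block; this is an editor's illustration, not the kernel itself. `alpha`, `beta` and `eps` are the running numerator, denominator and max-shift of the state, and `w` arrives already negated (see `FusedRecurrentRWKV4Function.forward` further down):

import torch

def wkv_step(w, u, k_t, v_t, alpha, beta, eps):
    # Output: a max-shifted, softmax-like mix of the history (alpha/beta) and the current token.
    ukt = u + k_t
    tau = torch.maximum(ukt, eps)
    e1, e2 = torch.exp(eps - tau), torch.exp(ukt - tau)
    wkv = (e1 * alpha + e2 * v_t) / (e1 * beta + e2)
    # State update: decay the history by w and fold in the current token,
    # keeping eps as the running maximum so every exponential stays <= 1.
    w_eps = w + eps
    eps_new = torch.maximum(w_eps, k_t)
    e1, e2 = torch.exp(w_eps - eps_new), torch.exp(k_t - eps_new)
    return wkv, e1 * alpha + e2 * v_t, e1 * beta + e2, eps_new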
+ alpha = tl.load(alpha_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + beta = tl.load(beta_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + eps = tl.load(eps_ptr + cs * state_s_c, mask=cmask).to(tl.float32) + w = tl.load(w_ptr + cs * w_s_c, mask=cmask).to(tl.float32) + u = tl.load(u_ptr + cs * u_s_c, mask=cmask).to(tl.float32) + + for t in range(tsz): + kt = tl.load(k_ptr + t * k_s_t + cs * k_s_c, mask=cmask).to(tl.float32) + vt = tl.load(v_ptr + t * v_s_t + cs * v_s_c, mask=cmask).to(tl.float32) + + ukt = u + kt + tau = tl.maximum(ukt, eps) + e1a = exp(eps - tau) + e2a = exp(ukt - tau) + wkv = (e1a * alpha + e2a * vt) / (e1a * beta + e2a) + tl.store(wkv_ptr + t * wkv_s_t + cs * wkv_s_c, wkv, mask=cmask) + + w_eps = w + eps + eps = tl.maximum(w_eps, kt) + e1b = exp(w_eps - eps) + e2b = exp(kt - eps) + alpha = e1b * alpha + e2b * vt + beta = e1b * beta + e2b + tl.store(alpha_out_ptr + t * state_out_s_t + cs * state_out_s_c, alpha, mask=cmask) + tl.store(beta_out_ptr + t * state_out_s_t + cs * state_out_s_c, beta, mask=cmask) + tl.store(eps_out_ptr + t * state_out_s_t + cs * state_out_s_c, eps, mask=cmask) + + +def fused_recurrent_rwkv4_forward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, +) -> tuple[Tensor, Tensor]: + (bsz, tsz, chans) = k.shape + + # New tensors to output. + wkvs = k.new_empty(bsz, tsz, chans) + state_out = k.new_empty(bsz, 3, tsz, chans) + + # Constants. + block_size_c = get_block_size_c(chans) + + def grid(meta: dict[str, Any]) -> tuple[int, ...]: + return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"])) + + fused_recurrent_rwkv4_forward_kernel[grid]( + # W + w, + w.stride(0), + # U + u, + u.stride(0), + # K + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V + v, + v.stride(0), + v.stride(1), + v.stride(2), + # State + state, + state.stride(0), + state.stride(1), + state.stride(3), + # WKV + wkvs, + wkvs.stride(0), + wkvs.stride(1), + wkvs.stride(2), + # Output state + state_out, + state_out.stride(0), + state_out.stride(1), + state_out.stride(2), + state_out.stride(3), + # Params + chans, + tsz, + BLOCK_SIZE_C=block_size_c, + ) + + state_out = torch.cat((state, state_out), dim=2) + + return wkvs, state_out + + +@triton.jit +def fused_recurrent_rwkv4_backward_kernel( + # W + w_ptr, + w_s_c, + # U + u_ptr, + u_s_c, + # K + k_ptr, + k_s_b, + k_s_t, + k_s_c, + # V + v_ptr, + v_s_b, + v_s_t, + v_s_c, + # State + state_ptr, + state_s_b, + state_s_abe, + state_s_t, + state_s_c, + # WKV grad + gwkv_ptr, + gwkv_s_b, + gwkv_s_t, + gwkv_s_c, + # Output state grad + gstate_out_ptr, + gstate_out_s_b, + gstate_out_s_abe, + gstate_out_s_c, + # W grad + gw_ptr, + gw_s_c, + # U grad + gu_ptr, + gu_s_c, + # K grad + gk_ptr, + gk_s_b, + gk_s_t, + gk_s_c, + # V grad + gv_ptr, + gv_s_b, + gv_s_t, + gv_s_c, + # State grad + gstate_ptr, + gstate_s_b, + gstate_s_abe, + gstate_s_c, + # Params + tsz, + chans, + BLOCK_SIZE_C: tl.constexpr, +): + # Parallelize over the batch dimension. + b_idx = tl.program_id(0) + c_idx = tl.program_id(1) + + cs = (c_idx * BLOCK_SIZE_C) + tl.arange(0, BLOCK_SIZE_C) + cmask = cs < chans + + # Pointers to the batch (and possibly channel) for the input tensors. + k_ptr = k_ptr + b_idx * k_s_b + v_ptr = v_ptr + b_idx * v_s_b + alpha_ptr = state_ptr + b_idx * state_s_b + beta_ptr = state_ptr + b_idx * state_s_b + state_s_abe + eps_ptr = state_ptr + b_idx * state_s_b + 2 * state_s_abe + + # Pointers to the batch (and possibly channel) for the output tensors. 
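+    # gk/gv follow the same [B, T, C] strides as k/v and are offset per batch below, while gw/gu
+    # are [C]-shaped buffers into which every program adds its contribution at the end of the kernel.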
+ gk_ptr = gk_ptr + b_idx * gk_s_b + gv_ptr = gv_ptr + b_idx * gv_s_b + + # Pointers to gradients which were recieved by the function. + gwkv_ptr = gwkv_ptr + b_idx * gwkv_s_b + galpha_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gbeta_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + gstate_out_s_abe + geps_out_ptr = gstate_out_ptr + b_idx * gstate_out_s_b + 2 * gstate_out_s_abe + + # Loads parameters. + galpha = tl.load(galpha_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + gbeta = tl.load(gbeta_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + geps = tl.load(geps_out_ptr + gstate_out_s_c * cs, mask=cmask).to(tl.float32) + w = tl.load(w_ptr + w_s_c * cs, mask=cmask).to(tl.float32) + u = tl.load(u_ptr + u_s_c * cs, mask=cmask).to(tl.float32) + + # Gradient accumulators. + gw = tl.zeros_like(w) + gu = tl.zeros_like(u) + + alpha_prev = tl.load(alpha_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + beta_prev = tl.load(beta_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + eps_prev = tl.load(eps_ptr + tsz * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + + for t in range(tsz): + tc = tsz - t - 1 + + kt = tl.load(k_ptr + tc * k_s_t + k_s_c * cs, mask=cmask).to(tl.float32) + vt = tl.load(v_ptr + tc * v_s_t + v_s_c * cs, mask=cmask).to(tl.float32) + + alpha_curr = alpha_prev + beta_curr = beta_prev + eps_curr = eps_prev + + alpha_prev = tl.load(alpha_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + beta_prev = tl.load(beta_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + eps_prev = tl.load(eps_ptr + tc * state_s_t + state_s_c * cs, mask=cmask).to(tl.float32) + + ukt = u + kt + tau = tl.maximum(ukt, eps_prev) + e1 = exp(eps_prev - tau) + e2 = exp(ukt - tau) + + euke = exp(ukt + eps_prev - 2 * tau) + + denom = e1 * beta_prev + e2 + denom_sq = denom * denom + + gwkvt = tl.load(gwkv_ptr + tc * gwkv_s_t + gwkv_s_c * cs, mask=cmask).to(tl.float32) + + # Backpropagates wkv gradients. + guk = gwkvt * e2 * (e1 * beta_prev * vt - e1 * alpha_prev) / denom_sq + gu += guk + gk = guk + gv = gwkvt * e2 / denom + + galpha_wkv = gwkvt * e1 / denom + gbeta_wkv = -gwkvt * e1 * (e2 * vt + e1 * alpha_prev) / denom_sq + geps_wkv_denom = e1 * beta_prev + e2 + geps_wkv = gwkvt * euke * (alpha_prev - vt * beta_prev) / (geps_wkv_denom * geps_wkv_denom) + + e1 = exp(w + eps_prev - eps_curr) + e2 = exp(kt - eps_curr) + + # Backpropagates alpha gradients. + galpha_we = galpha * e1 * alpha_prev + gw += galpha_we + gk += galpha * e2 * vt + gv += galpha * e2 + geps += galpha * -alpha_curr + + # Backpropagates beta gradients. + gbeta_we = gbeta * e1 * beta_prev + gw += gbeta_we + gk += gbeta * e2 + geps += gbeta * -beta_curr + + # Backpropagates epsilon gradients. + geps_mask = w + eps_prev > kt + geps_we = tl.where(geps_mask, geps, tl.zeros_like(geps)) + gw += geps_we + gk += tl.where(geps_mask, tl.zeros_like(geps), geps) + + # Stores the gradients for k and v. + tl.store(gk_ptr + tc * gk_s_t + gk_s_c * cs, gk, mask=cmask) + tl.store(gv_ptr + tc * gv_s_t + gv_s_c * cs, gv, mask=cmask) + + # Computes new gradients for alpha and beta. + galpha = galpha * e1 + galpha_wkv + gbeta = gbeta * e1 + gbeta_wkv + geps = galpha_we + gbeta_we + geps_we + geps_wkv + + # Stores final gradients for alpha and beta. 
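+    # At this point the reverse-time loop has finished: the remaining galpha/gbeta/geps are the
+    # gradients w.r.t. the initial state (written to `gstate` below), and the accumulated gw/gu
+    # are added into the shared per-channel buffers.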
+ galpha_ptr = gstate_ptr + b_idx * gstate_s_b + gbeta_ptr = gstate_ptr + b_idx * gstate_s_b + gstate_s_abe + geps_ptr = gstate_ptr + b_idx * gstate_s_b + 2 * gstate_s_abe + tl.store(galpha_ptr + gstate_s_c * cs, galpha, mask=cmask) + tl.store(gbeta_ptr + gstate_s_c * cs, gbeta, mask=cmask) + tl.store(geps_ptr + gstate_s_c * cs, geps, mask=cmask) + + # Stores final gradients for w and u. + gw_temp = tl.load(gw_ptr + gw_s_c * cs, mask=cmask).to(tl.float32) + gw_temp += gw + tl.store(gw_ptr + gw_s_c * cs, gw_temp, mask=cmask) + gu_temp = tl.load(gu_ptr + gu_s_c * cs, mask=cmask).to(tl.float32) + gu_temp += gu + tl.store(gu_ptr + gu_s_c * cs, gu_temp, mask=cmask) + + +def fused_recurrent_rwkv4_backward( + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, + grad_wkv: Tensor, + grad_state: Tensor, +) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + bsz, tsz, chans = k.shape + + gw = torch.zeros_like(w) # New tensors to output. + gu = torch.zeros_like(u) + gk = torch.empty_like(k) + gv = torch.empty_like(v) + gstate = k.new_empty(bsz, 3, 1, chans) + + block_size_c = get_block_size_c(chans) # Constants. + + def grid(meta: dict[str, Any]) -> tuple[int, ...]: + return (bsz, triton.cdiv(chans, meta["BLOCK_SIZE_C"])) + + fused_recurrent_rwkv4_backward_kernel[grid]( + # W + w, + w.stride(0), + # U + u, + u.stride(0), + # K + k, + k.stride(0), + k.stride(1), + k.stride(2), + # V + v, + v.stride(0), + v.stride(1), + v.stride(2), + # State + state, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + # WKV grad + grad_wkv, + grad_wkv.stride(0), + grad_wkv.stride(1), + grad_wkv.stride(2), + # Output state grad + grad_state, + grad_state.stride(0), + grad_state.stride(1), + grad_state.stride(3), + # W grad + gw, + gw.stride(0), + # U grad + gu, + gu.stride(0), + # K grad + gk, + gk.stride(0), + gk.stride(1), + gk.stride(2), + # V grad + gv, + gv.stride(0), + gv.stride(1), + gv.stride(2), + # State grad + gstate, + gstate.stride(0), + gstate.stride(1), + gstate.stride(3), + # Params + tsz, + chans, + BLOCK_SIZE_C=block_size_c, + ) + + return gw, gu, gk, gv, gstate + + +class FusedRecurrentRWKV4Function(Function): + @staticmethod + def forward( + ctx: FunctionCtx, + w: Tensor, + u: Tensor, + k: Tensor, + v: Tensor, + state: Tensor, + ) -> tuple[Tensor, Tensor]: + ctx.input_dtype = k.dtype + + w = -torch.exp(w.float().contiguous()) + if k.dtype == torch.float16: + u = u.float() + k = k.float() + v = v.float() + u = u.contiguous() + k = k.contiguous() + v = v.contiguous() + wkv, state_out = fused_recurrent_rwkv4_forward(w, u, k, v, state) + ctx.save_for_backward(w, u, k, v, state_out[:, :, :-1]) + return wkv, state_out[:, :, -1:] + + @staticmethod + @once_differentiable + def backward(ctx: FunctionCtx, gwkv: Tensor, gstate: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + w, u, k, v, state = cast(tuple[Tensor, ...], ctx.saved_tensors) + gw, gu, gk, gv, gstate = fused_recurrent_rwkv4_backward(w, u, k, v, state, gwkv, gstate) + return gw, gu, gk, gv, gstate + + +def fused_recurrent_rwkv4(w: Tensor, u: Tensor, k: Tensor, v: Tensor, state: Tensor) -> tuple[Tensor, Tensor]: + return FusedRecurrentRWKV4Function.apply(w, u, k, v, state) diff --git a/fla/ops/simple_gla/README.md b/fla/ops/simple_gla/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2a64f3dcdee7ff9863089a6b47ef694f6234ab8f --- /dev/null +++ b/fla/ops/simple_gla/README.md @@ -0,0 +1,10 @@ +# Simple GLA + +Gating mechanism in [Gated 
RFA](https://arxiv.org/abs/2103.02143), [Mamba2](https://arxiv.org/abs/2405.21060) and [YOCO](https://arxiv.org/abs/2405.05254) (a.k.a., Gated RetNet). + +Compared to GLA, the gating is head-wise instead of elementwise. +As a result, we can adapt the RetNet kernel for training using matmul w/o numerical instability. +It is faster than GLA but has less expressive power. +I will use it as a baseline for the GLA. + +$S_{t+1} = g_{t+1} \odot S_{t} + K_{t+1} V_{t+1}^{\top}$ where $g$ is a scalar. diff --git a/fla/ops/simple_gla/__init__.py b/fla/ops/simple_gla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..561e3afbf81e8ab1b0fe738e5c5e8d1e1626868e --- /dev/null +++ b/fla/ops/simple_gla/__init__.py @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from .chunk import chunk_simple_gla +from .fused_recurrent import fused_recurrent_simple_gla +from .parallel import parallel_simple_gla + +__all__ = [ + 'chunk_simple_gla', + 'fused_recurrent_simple_gla', + 'parallel_simple_gla' +] diff --git a/fla/ops/simple_gla/__pycache__/__init__.cpython-311.pyc b/fla/ops/simple_gla/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eec2dba422cf8a058d584d4150bbe498f59cfc1 Binary files /dev/null and b/fla/ops/simple_gla/__pycache__/__init__.cpython-311.pyc differ diff --git a/fla/ops/simple_gla/__pycache__/chunk.cpython-311.pyc b/fla/ops/simple_gla/__pycache__/chunk.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93bc41b4178f12ee4a6ec771d4352f4de61c93e6 Binary files /dev/null and b/fla/ops/simple_gla/__pycache__/chunk.cpython-311.pyc differ diff --git a/fla/ops/simple_gla/__pycache__/fused_recurrent.cpython-311.pyc b/fla/ops/simple_gla/__pycache__/fused_recurrent.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0d68820e028e77b4efb50f468883fede5351c78 Binary files /dev/null and b/fla/ops/simple_gla/__pycache__/fused_recurrent.cpython-311.pyc differ diff --git a/fla/ops/simple_gla/__pycache__/parallel.cpython-311.pyc b/fla/ops/simple_gla/__pycache__/parallel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843db7c2e109622e7b30b30b80cf79e2e29967d3 Binary files /dev/null and b/fla/ops/simple_gla/__pycache__/parallel.cpython-311.pyc differ diff --git a/fla/ops/simple_gla/chunk.py b/fla/ops/simple_gla/chunk.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c13a54aafca67e45d81021a9c9112c1f896af1 --- /dev/null +++ b/fla/ops/simple_gla/chunk.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton + +from fla.ops.common.chunk_h import chunk_bwd_dh, chunk_fwd_h +from fla.ops.common.chunk_o import chunk_bwd_dqkwg, chunk_bwd_dv, chunk_fwd_o +from fla.ops.utils import chunk_local_cumsum +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard + + +def chunk_simple_gla_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) if g is not None else None + h, ht = chunk_fwd_h( + k=k, + v=v, + g=g, + 
gk=None, + gv=None, + h0=initial_state, + output_final_state=output_final_state, + states_in_fp32=False, + offsets=offsets, + head_first=head_first, + chunk_size=chunk_size + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v, + g=g, + h=h, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return g, o, ht + + +def chunk_simple_gla_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + initial_state: torch.Tensor, + do: torch.Tensor, + dht: torch.Tensor, + scale: float, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, + head_first: bool = True, + chunk_size: int = 64 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # (SY 09/22) states_in_fp32 seems not affecting the error of dg but for safety, set to True + h, _ = chunk_fwd_h( + k=k, + v=v, + g=g, + gk=None, + gv=None, + h0=initial_state, + output_final_state=False, + states_in_fp32=True, + offsets=offsets, + head_first=head_first, + chunk_size=chunk_size + ) + dh, dh0 = chunk_bwd_dh( + q=q, + k=k, + v=v, + g=g, + gk=None, + gv=None, + do=do, + h0=initial_state, + dht=dht, + scale=scale, + states_in_fp32=True, + offsets=offsets, + head_first=head_first, + chunk_size=chunk_size + ) + dq, dk, _, dg = chunk_bwd_dqkwg( + q=q, + k=k, + v=v, + g=g, + h=h, + do=do, + dh=dh, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + dv = chunk_bwd_dv( + q=q, + k=k, + g=g, + do=do, + dh=dh, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + return dq, dk, dv, dg, dh0 + + +class ChunkSimpleGLAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward( + ctx, + q, + k, + v, + g, + scale, + initial_state, + output_final_state, + offsets, + head_first + ): + T = q.shape[2] if head_first else q.shape[1] + chunk_size = min(64, max(16, triton.next_power_of_2(T))) + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = None + if offsets is not None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + + g, o, ht = chunk_simple_gla_fwd( + q=q, + k=k, + v=v, + g=g, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + ctx.save_for_backward(q, k, v, g, initial_state) + ctx.chunk_size = chunk_size + ctx.scale = scale + ctx.offsets = offsets + ctx.indices = indices + ctx.head_first = head_first + return o.to(q.dtype), ht + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, dht): + chunk_size, scale, offsets, indices, head_first = ctx.chunk_size, ctx.scale, ctx.offsets, ctx.indices, ctx.head_first + q, k, v, g, initial_state = ctx.saved_tensors + dq, dk, dv, dg, dh0 = chunk_simple_gla_bwd( + q=q, + k=k, + v=v, + g=g, + initial_state=initial_state, + do=do, + dht=dht, + scale=scale, + offsets=offsets, + indices=indices, + head_first=head_first, + chunk_size=chunk_size + ) + if g is not None: + dg = chunk_local_cumsum(dg, chunk_size, reverse=True, 
offsets=offsets, + indices=indices, head_first=head_first).to(g.dtype) + else: + dg = None + return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), dg, None, dh0, None, None, None + + +@torch.compiler.disable +def chunk_simple_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, # log decay + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + g (torch.Tensor): + Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. + Compared to GLA, the gating is head-wise instead of elementwise. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.simple_gla import chunk_simple_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, device='cuda')) + >>> o, ht = chunk_simple_gla(q, k, v, g, + initial_state=None, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_simple_gla(q, k, v, g, + initial_state=None, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." 
+ f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkSimpleGLAFunction.apply( + q, + k, + v, + g, + scale, + initial_state, + output_final_state, + cu_seqlens, + head_first + ) + return o, final_state diff --git a/fla/ops/simple_gla/fused_recurrent.py b/fla/ops/simple_gla/fused_recurrent.py new file mode 100644 index 0000000000000000000000000000000000000000..7012497d8554903ab712c351d38cbd116f6e9f0d --- /dev/null +++ b/fla/ops/simple_gla/fused_recurrent.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2024, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch + +from fla.ops.common.fused_recurrent import fused_recurrent + + +def fused_recurrent_simple_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + reverse: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + g (torch.Tensor): + Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. + Compared to GLA, the gating is head-wise instead of elementwise. + scale (Optional[int]): + Scale factor for the attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + reverse (Optional[bool]): + If `True`, process the state passing in reverse order. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. 
+ + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.simple_gla import fused_recurrent_simple_gla + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = torch.randn(B, T, H, K, device='cuda') + >>> v = torch.randn(B, T, H, V, device='cuda') + >>> g = F.logsigmoid(torch.randn(B, T, H, K, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, device='cuda') + >>> o, ht = fused_recurrent_simple_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + head_first=False) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h d -> 1 (b t) h d'), (q, k, v, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_recurrent_simple_gla(q, k, v, g, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens, + head_first=False) + >>> assert o.allclose(o_var.view(o.shape)) + >>> assert ht.allclose(ht_var) + """ + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError(f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing.") + if head_first: + raise RuntimeError("Sequences with variable lengths are not supported for head-first mode") + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError(f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}.") + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = fused_recurrent( + q=q, + k=k, + v=v, + g=g, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + reverse=reverse, + cu_seqlens=cu_seqlens, + head_first=head_first + ) + return o, final_state diff --git a/fla/ops/simple_gla/naive.py b/fla/ops/simple_gla/naive.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcc96ebeb720cc8b9699793ee6bdf8d3d39fdaa --- /dev/null +++ b/fla/ops/simple_gla/naive.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +import torch +from einops import rearrange + + +def torch_simple_gla(q, k, v, g, chunk_size=64, scale=None): + if scale is None: + scale = (q.shape[-1] ** -0.5) + q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale + k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) + v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size) + g = rearrange(g, 'b h (n c) -> b h n c', c=chunk_size) + g = g.cumsum(-1) + kv = k.transpose(-1, -2) @ (v * (-g + g[:, :, :, -1, None]).exp()[..., None]) + S = torch.zeros_like(kv) + + for i in range(1, g.shape[-2]): + S[:, :, i] = S[:, :, i-1].clone() * g[:, :, i-1, -1, None, None].exp() + kv[:, :, i-1] + + inter = (q * g[..., None].exp()) @ S + attn = q @ k.transpose(-1, -2) + attn = attn * (g[..., None] - g[..., None, :]).exp() + attn = attn.masked_fill(torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), 0) + intra = attn @ v + o = inter + intra + return rearrange(o, 'b h n c d -> b h (n c) d') + + +def torch_simple_gla_recurrent(q, k, v, g, scale=None, initial_state=None, output_final_state=True): + B, H, T, DK = q.shape + original_dtype = q.dtype + q, k, v, g = 
q.float(), k.float(), v.float(), g.float() + if scale is None: + scale = DK ** -0.5 + q = q * scale + _, _, _, DV = v.shape + if initial_state is None: + S = torch.zeros(B, H, DK, DV) + else: + S = initial_state + o = torch.zeros(B, H, T, DV).to(q) + for i in range(T): + gate = g[:, :, i].exp() + key = k[:, :, i] + value = v[:, :, i] + kv = key.unsqueeze(-1) * value.unsqueeze(-2) + S = S.clone() * gate.unsqueeze(-1).unsqueeze(-1) + kv + q_i = q[:, :, i, :] + o_i = (q_i.unsqueeze(-1) * S).sum(-2) + o[:, :, i] = o_i + if not output_final_state: + S = None + return o.to(original_dtype), S diff --git a/fla/ops/simple_gla/parallel.py b/fla/ops/simple_gla/parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..d0ad1c8c33dd846eb1d1cf3c582836b6110017d7 --- /dev/null +++ b/fla/ops/simple_gla/parallel.py @@ -0,0 +1,722 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from fla.ops.utils import chunk_global_cumsum, chunk_local_cumsum +from fla.ops.utils.op import safe_exp +from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, input_guard, is_intel_alchemist + +# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449 +triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {} + + +@triton.heuristics({ + 'NV': lambda args: triton.cdiv(args['V'], args['BV']), + 'OUTPUT_ATTENTIONS': lambda args: args['attn'] is not None, + 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'USE_G': lambda args: args['g'] is not None +}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8, 16] + for num_stages in [2, 3, 4] + ], + key=["BT", "BS", "BK", "BV", "USE_G"], +) +@triton.jit +def parallel_simple_gla_fwd_kernel( + q, + k, + v, + g, + o, + attn, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NV: tl.constexpr, + OUTPUT_ATTENTIONS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_OFFSETS: tl.constexpr, + USE_G: tl.constexpr +): + tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time") + i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_k, i_v = i_kv // NV, i_kv % NV + i_b, i_h = i_bh // H, i_bh % H + o += i_k * B * T * H * V + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + q += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + k += i_bh * T * K if HEAD_FIRST else (bos * H + i_h) * K + v += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + o += i_bh * T * V if HEAD_FIRST else (bos * H + i_h) * V + if USE_G: + g += i_bh * T if HEAD_FIRST else bos * H + i_h + if OUTPUT_ATTENTIONS: + attn += (bos * H + i_h * T) * T + i_k * B * H * T * T + stride_qk = K if HEAD_FIRST else H * K + stride_vo = V if HEAD_FIRST else H * V + stride_g = 1 if HEAD_FIRST else H + + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + # the Q block is kept in the shared memory throughout the whole kernel + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * 
scale).to(b_q.dtype) + b_o = tl.zeros([BT, BV], dtype=tl.float32) + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + # [BS] + o_k = i_t * BT + tl.arange(0, BS) + # Q block and K block have overlap. + # masks required + if USE_G: + p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,)) + # [BT,] + b_gq = tl.load(p_gq, boundary_check=(0,)).to(tl.float32) + # rescale interchunk output + else: + b_gq = None + + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BS] + m_s = o_q[:, None] >= o_k[None, :] + b_s = tl.dot(b_q, b_k) + if USE_G: + p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)) + b_s *= safe_exp(b_gq[:, None] - b_gk[None, :]) + b_s = tl.where(m_s, b_s, 0) + else: + b_s = tl.where(m_s, b_s, 0) + # [BT, BV] + if i_s >= 0: + b_o += tl.dot(b_s.to(b_q.dtype), b_v) + if OUTPUT_ATTENTIONS: + p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0)) + tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1)) + o_k += BS + + for i_s in range(i_t * BT - BS, -BS, -BS): + p_k = tl.make_block_ptr(k, (K, T), (1, stride_qk), (i_k * BK, i_s), (BK, BS), (0, 1)) + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BK, BS] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BS, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_s = tl.dot(b_q, b_k) + if USE_G: + p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g) + b_gp = tl.load(g + (i_s-1) * stride_g) if i_s % BT > 0 else 0. + # No concrete meaning. Just to avoid some layout bugs. 
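+            # `b_gq` accumulates the log-decay from the current query positions back to the end of
+            # the block being processed: (b_gn - b_g) supplies each key's decay up to its block end,
+            # and b_gq is then extended by this block's total decay (b_gn - b_gp) before moving on.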
+ b_s *= safe_exp(b_gq[:, None] + (b_gn - b_g)[None, :]) + b_gq += (b_gn - b_gp) + if OUTPUT_ATTENTIONS: + p_a = tl.make_block_ptr(attn, (T, T), (T, 1), (i_t * BT, i_s), (BT, BS), (1, 0)) + tl.store(p_a, b_s.to(p_a.dtype.element_ty), boundary_check=(0, 1)) + if i_s >= 0: + b_o += tl.dot(b_s.to(b_v.dtype), b_v) + p_o = tl.make_block_ptr(o, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit(do_not_specialize=['T']) +def parallel_simple_gla_bwd_kernel_dq( + i_t, + i_k, + i_v, + q, + k, + v, + g, + do, + dq, + dg, + stride_qk, + stride_vo, + stride_g, + scale, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr +): + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BT, BK] + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + + for i_s in range(0, i_t * BT, BS): + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BV] @ [BV, BS] = [BT, BS] + b_ds = tl.dot(b_do, b_v) + if USE_G: + p_g = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_gn = tl.load(g + (min(i_s + BS, T) - 1) * stride_g) + b_gp = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0. + b_ds *= safe_exp(b_gn - b_g)[None, :] + if i_s > 0: + b_dq *= safe_exp(b_gn - b_gp) + # [BT, BS] @ [BS, BK] = [BT, BK] + b_dq += tl.dot(b_ds.to(b_v.dtype), b_k) + + if USE_G: + p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,)) + # [BT,] + b_gq = tl.load(p_gq, boundary_check=(0,)) + # [BT, BK] + b_dq *= safe_exp(b_gq)[:, None] + + # [BT] + o_q = i_t * BT + tl.arange(0, BT) + # [BS] + o_k = i_t * BT + tl.arange(0, BS) + # Q block and K block have overlap. 
masks required + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0)) + p_v = tl.make_block_ptr(v, (V, T), (1, stride_vo), (i_v * BV, i_s), (BV, BS), (0, 1)) + # [BS, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BV, BS] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BV] @ [BV, BS] = [BT, BS] + b_ds = tl.dot(b_do, b_v) + if USE_G: + p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)) + b_ds *= safe_exp(b_gq[:, None] - b_gk[None, :]) + b_ds = tl.where(o_q[:, None] >= o_k[None, :], b_ds, 0) + # [BT, BK] + b_dq += tl.dot(b_ds.to(b_k.dtype), b_k) + o_k += BS + + b_dq *= scale + p_dq = tl.make_block_ptr(dq, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + if USE_G: + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_dg = tl.sum(b_dq * b_q, 1) + p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +@triton.jit(do_not_specialize=['T']) +def parallel_simple_gla_bwd_kernel_dkv( + i_t, + i_k, + i_v, + q, + k, + v, + g, + do, + dk, + dv, + dg, + scale, + stride_qk, + stride_vo, + stride_g, + T, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr +): + # [BT, BK] + p_k = tl.make_block_ptr(k, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + # [BT, BV] + p_v = tl.make_block_ptr(v, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_dv = tl.zeros([BT, BV], dtype=tl.float32) + if USE_G: + p_gk = tl.make_block_ptr(g, (T,), (stride_g,), (i_t * BT,), (BT,), (0,)) + b_gk = tl.load(p_gk, boundary_check=(0,)) + NTS = tl.cdiv(T, BS) + # [BT, BK] + for i_s in range(NTS * BS - BS, (i_t + 1) * BT - BS, -BS): + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_ds = tl.dot(b_v, tl.trans(b_do)) + b_s = tl.dot(b_k, tl.trans(b_q)) + if USE_G: + p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)) + b_gp = tl.load(g + (min(i_s + BS, T) - 1) * stride_g) + b_gn = tl.load(g + (i_s - 1) * stride_g) if i_s % BT > 0 else 0. 
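+            # Here b_gp is the chunk-local cumulative decay at the end of this query block and b_gn
+            # the value just before it (roles mirrored w.r.t. the dq kernel above): exp(b_gp - b_gn)
+            # rescales the running dk/dv accumulators by the block's total decay, while
+            # exp(b_gq - b_gn) applies the per-position decay inside the block.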
+ if i_s >= 0: + tmp = safe_exp(b_gp - b_gn) + b_dk *= tmp + b_dv *= tmp + tmp2 = safe_exp(b_gq - b_gn) + b_ds *= tmp2[None, :] + b_s *= tmp2[None, :] + # [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + # [BT, BV] + b_dv += tl.dot(b_s.to(b_do.dtype), b_do) + + if USE_G: + b_g_last = tl.load(g + (min(i_t * BT + BT, T) - 1) * stride_g) + if i_t >= 0: + tmp2 = safe_exp(b_g_last - b_gk)[:, None] + b_dk *= tmp2 + b_dv *= tmp2 + + o_q = i_t * BT + tl.arange(0, BS) + o_k = i_t * BT + tl.arange(0, BT) + for i_s in range(i_t * BT, min((i_t + 1) * BT, T), BS): + p_q = tl.make_block_ptr(q, (T, K), (stride_qk, 1), (i_s, i_k * BK), (BS, BK), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (stride_vo, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) + # [BS, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BS, BV] + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BS] + b_ds = tl.dot(b_v, tl.trans(b_do)) + b_s = tl.dot(b_k, tl.trans(b_q)) + if USE_G: + p_gq = tl.make_block_ptr(g, (T,), (stride_g,), (i_s,), (BS,), (0,)) + b_gq = tl.load(p_gq, boundary_check=(0,)) + if i_s >= 0: + tmp = safe_exp(-b_gk[:, None] + b_gq[None, :]) + b_ds *= tmp + b_s *= tmp + m_s = o_k[:, None] <= o_q[None, :] + b_s = tl.where(m_s, b_s, 0) + b_ds = tl.where(m_s, b_ds, 0) + # [BT, BK] + b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) + b_dv += tl.dot(b_s.to(b_do.dtype), b_do) + o_q += BS + b_dk *= scale + b_dv *= scale + p_dk = tl.make_block_ptr(dk, (T, K), (stride_qk, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dv = tl.make_block_ptr(dv, (T, V), (stride_vo, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + if USE_G: + p_dg = tl.make_block_ptr(dg, (T,), (stride_g,), (i_t * BT,), (BT,), (0,)) + b_dg = tl.load(p_dg, boundary_check=(0,)) + b_dg -= tl.sum(b_dk * b_k, 1) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({ + 'NV': lambda args: triton.cdiv(args['V'], args['BV']), + 'USE_OFFSETS': lambda args: args['offsets'] is not None, + 'USE_G': lambda args: args['g'] is not None +}) +@triton.autotune( + configs=[ + triton.Config(triton_config, num_warps=num_warps) + for num_warps in [2, 4, 8, 16] + ], + key=['BT', 'BS', 'BK', 'BV', 'USE_G'], +) +@triton.jit(do_not_specialize=['T']) +def parallel_simple_gla_bwd_kernel( + q, + k, + v, + g, + do, + dq, + dk, + dv, + dg, + scale, + offsets, + indices, + T, + B: tl.constexpr, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + NV: tl.constexpr, + USE_OFFSETS: tl.constexpr, + HEAD_FIRST: tl.constexpr, + USE_G: tl.constexpr +): + tl.static_assert(not (USE_OFFSETS and HEAD_FIRST), "USE_OFFSETS and HEAD_FIRST cannot be True at the same time") + i_kv, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_k, i_v = i_kv // NV, i_kv % NV + i_b, i_h = i_bh // H, i_bh % H + dq += i_v * B * H * T * K + dk += i_v * B * H * T * K + dv += i_k * B * H * T * V + if USE_G: + dg += i_kv * B * H * T + + if USE_OFFSETS: + i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + q += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K + k += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K + v += (i_bh * T * V) if HEAD_FIRST else (bos * 
H + i_h) * V + do += (i_bh * T * V) if HEAD_FIRST else (bos * H + i_h) * V + dq += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K + dk += (i_bh * T * K) if HEAD_FIRST else (bos * H + i_h) * K + dv += (i_bh * T * V) if HEAD_FIRST else (bos * H + i_h) * V + if USE_G: + g += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + dg += (i_bh * T) if HEAD_FIRST else (bos * H + i_h) + stride_qk = K if HEAD_FIRST else H * K + stride_vo = V if HEAD_FIRST else H * V + stride_g = 1 if HEAD_FIRST else H + + parallel_simple_gla_bwd_kernel_dq( + i_t=i_t, + i_k=i_k, + i_v=i_v, + q=q, + k=k, + v=v, + g=g, + do=do, + dq=dq, + dg=dg, + scale=scale, + stride_qk=stride_qk, + stride_vo=stride_vo, + stride_g=stride_g, + T=T, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + USE_G=USE_G + ) + tl.debug_barrier() + parallel_simple_gla_bwd_kernel_dkv( + i_t=i_t, + i_k=i_k, + i_v=i_v, + q=q, + k=k, + v=v, + g=g, + do=do, + dk=dk, + dv=dv, + dg=dg, + scale=scale, + stride_qk=stride_qk, + stride_vo=stride_vo, + stride_g=stride_g, + T=T, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + USE_G=USE_G + ) + + +def parallel_simple_gla_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + scale: float, + output_attentions: bool = False, + chunk_size: int = 128, + head_first: bool = True, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT, BS = chunk_size, 32 + if check_shared_mem('hopper', k.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + elif check_shared_mem('ampere', k.device.index): + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + else: + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert BT % BS == 0 + + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + # local cumulative decay in log space + if g is not None: + g = chunk_local_cumsum(g, chunk_size, offsets=offsets, indices=indices, head_first=head_first) + grid = (NK * NV, NT, B * H) + o = torch.empty(NK, *v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device) + attn = q.new_zeros(NK, B, H, T, T) if output_attentions else None + + parallel_simple_gla_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + o=o, + attn=attn, + scale=scale, + offsets=offsets, + indices=indices, + B=B, + H=H, + T=T, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + HEAD_FIRST=head_first, + ) + o = o.sum(0) + + if output_attentions: + attn = attn.sum(0) + return o, g, attn + + +def parallel_simple_gla_bwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + scale: float, + chunk_size: int = 128, + head_first: bool = True, + offsets: Optional[torch.LongTensor] = None, + indices: Optional[torch.LongTensor] = None, +): + if head_first: + B, H, T, K, V = *k.shape, v.shape[-1] + else: + B, T, H, K, V = *k.shape, v.shape[-1] + BT, BS = chunk_size, 32 + if check_shared_mem('hopper', k.device.index): + BK = min(256, triton.next_power_of_2(K)) + BV = min(256, triton.next_power_of_2(V)) + elif check_shared_mem('ampere', k.device.index): + BK = min(128, triton.next_power_of_2(K)) + BV = min(128, triton.next_power_of_2(V)) + elif check_shared_mem('ada', k.device.index): + BK = min(64, triton.next_power_of_2(K)) + BV = min(64, triton.next_power_of_2(V)) + else: + 
BK = min(32, triton.next_power_of_2(K)) + BV = min(32, triton.next_power_of_2(V)) + + NK = triton.cdiv(K, BK) + NV = triton.cdiv(V, BV) + assert BT % BS == 0 + + dq = torch.empty(NV, * q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) + dk = torch.empty(NV, * k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) + dv = torch.empty(NK, * v.shape, dtype=v.dtype if NK == 1 else torch.float, device=q.device) + dg = torch.empty(NK*NV, *g.shape, dtype=torch.float, device=q.device) if g is not None else None + + NT = triton.cdiv(T, BT) if offsets is None else len(indices) + + grid = (NK * NV, NT, B * H) + parallel_simple_gla_bwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + do=do, + dq=dq, + dk=dk, + dv=dv, + dg=dg, + offsets=offsets, + indices=indices, + scale=scale, + T=T, + B=B, + H=H, + K=K, + V=V, + BT=BT, + BS=BS, + BK=BK, + BV=BV, + HEAD_FIRST=head_first + ) + dq = dq.sum(0) + dk = dk.sum(0) + dv = dv.sum(0) + dg = chunk_global_cumsum(dg.sum(0), reverse=True, head_first=head_first, offsets=offsets) if g is not None else None + return dq, dk, dv, dg + + +class ParallelSimpleGLAFunction(torch.autograd.Function): + + @staticmethod + @input_guard + @autocast_custom_fwd + def forward(ctx, q, k, v, g, scale, output_attentions, head_first, offsets): + chunk_size = 128 + ctx.dtype = q.dtype + + # 2-d indices denoting the offsets of chunks in each sequence + # for example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64, + # then there are 2 and 4 chunks in the 1st and 2nd sequences respectively, and `indices` will be + # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] + indices = None + if offsets is not None: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(offsets[1:] - offsets[:-1], chunk_size).tolist()]) + indices = torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(offsets) + + o, g, attn = parallel_simple_gla_fwd( + q=q, + k=k, + v=v, + g=g, + scale=scale, + output_attentions=output_attentions, + head_first=head_first, + offsets=offsets, + indices=indices, + chunk_size=chunk_size) + ctx.save_for_backward(q, k, v, g, offsets, indices) + ctx.scale = scale + ctx.chunk_size = chunk_size + ctx.head_first = head_first + return o.to(q.dtype), attn + + @staticmethod + @input_guard + @autocast_custom_bwd + def backward(ctx, do, da=None): + q, k, v, g, offsets, indices = ctx.saved_tensors + dq, dk, dv, dg = parallel_simple_gla_bwd( + q=q, + k=k, + v=v, + g=g, + do=do, + scale=ctx.scale, + chunk_size=ctx.chunk_size, + offsets=offsets, + indices=indices, + head_first=ctx.head_first) + return dq.to(q), dk.to(k), dv.to(v), dg.to(ctx.dtype) if dg is not None else None, None, None, None, None + + +def parallel_simple_gla( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + output_attentions: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + k (torch.Tensor): + keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]` + v (torch.Tensor): + values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]` + g (torch.Tensor): + Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`. + Compared to GLA, the gating is head-wise instead of elementwise. + scale (Optional[int]): + Scale factor for attention scores. 
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + output_attentions (bool): + Whether to output the materialized attention scores of shape `[B, H, T, T]`. Default: `False`. + cu_seqlens (Optional[torch.LongTensor]): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. Default: `None`. + head_first (Optional[bool]): + Whether the inputs are in the head-first format. Default: `True`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. + attn (torch.Tensor): + Attention scores of shape `[B, H, T, T]` if `output_attentions=True` else `None`. + """ + if scale is None: + scale = k.shape[-1] ** -0.5 + if cu_seqlens is not None: + assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" + assert not head_first, "head_first must be False when cu_seqlens are provided" + if g is not None: + g = g.float() + if output_attentions: + assert cu_seqlens is None, "output_attentions=True is not supported with variable-length sequences" + o, attn = ParallelSimpleGLAFunction.apply(q, k, v, g, scale, output_attentions, head_first, cu_seqlens) + return o, attn
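
A note on the `dg` path, the least obvious part of the backward above. Writing $G_t=\sum_{u\le t} g_u$ for the cumulative log decay and $A_{ts}=\mathrm{scale}\,(q_t\cdot k_s)\,e^{G_t-G_s}$ for $s\le t$, so that $o_t=\sum_{s\le t}A_{ts}v_s$, the chain rule gives

$$
\frac{\partial L}{\partial G_t}=q_t\cdot\frac{\partial L}{\partial q_t}-k_t\cdot\frac{\partial L}{\partial k_t},
\qquad
\frac{\partial L}{\partial g_u}=\sum_{t\ge u}\frac{\partial L}{\partial G_t}.
$$

This is consistent with what the code does, up to the chunk-local bookkeeping of the gates: the `dq` pass writes its per-position contribution into the `dg` buffer, the `dkv` pass subtracts `tl.sum(b_dk * b_k, 1)` from it, and the host-side `parallel_simple_gla_bwd` finishes with `chunk_global_cumsum(dg.sum(0), reverse=True, ...)`, whose reverse cumulative sum realizes the $\sum_{t\ge u}$ above.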
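
To sanity-check the fused kernels, a naive O(T^2) PyTorch reference is handy. The sketch below is illustrative and not part of the library: `naive_simple_gla` is a name of our choosing, it assumes the head-first `[B, H, T, ...]` layout and that `g` holds per-head log forget gates (e.g. from `F.logsigmoid`), and it materializes the full T x T attention matrix, so it is only practical for short sequences. Expect small numerical differences against the Triton path, which accumulates in float32 with block-wise exponentials; adjust the import to wherever `parallel_simple_gla` is exported in your checkout.

import torch
import torch.nn.functional as F

from fla.ops.simple_gla import parallel_simple_gla  # assumption: adjust to your checkout's export path


def naive_simple_gla(q, k, v, g=None, scale=None):
    # q, k: [B, H, T, K]; v: [B, H, T, V]; g: [B, H, T] log forget gates (head-first layout).
    if scale is None:
        scale = k.shape[-1] ** -0.5
    T = q.shape[-2]
    s = torch.einsum('bhtk,bhsk->bhts', q.float(), k.float()) * scale
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    if g is not None:
        G = g.float().cumsum(-1)                                   # cumulative log decay
        d = (G[..., :, None] - G[..., None, :]).masked_fill(~mask, float('-inf'))
        s = s * d.exp()                                            # exp(G_t - G_s) for s <= t, 0 elsewhere
    else:
        s = s.masked_fill(~mask, 0.)
    o = torch.einsum('bhts,bhsv->bhtv', s, v.float())
    return o.to(v.dtype)


if __name__ == '__main__':
    torch.manual_seed(0)
    B, H, T, K, V = 2, 4, 256, 64, 64
    q, k, v = (torch.randn(B, H, T, d, device='cuda', dtype=torch.bfloat16, requires_grad=True)
               for d in (K, K, V))
    g = F.logsigmoid(torch.randn(B, H, T, device='cuda')).requires_grad_(True)

    ref = naive_simple_gla(q, k, v, g)
    ref.sum().backward()
    ref_grads = [t.grad.clone() for t in (q, k, v, g)]
    for t in (q, k, v, g):
        t.grad = None

    out, _ = parallel_simple_gla(q, k, v, g)   # head_first=True by default
    out.sum().backward()
    print('o  ', (out - ref).abs().max().item())
    for name, t, r in zip('qkvg', (q, k, v, g), ref_grads):
        print('d' + name, (t.grad - r).abs().max().item())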
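
The 2-D `indices` tensor built in `ParallelSimpleGLAFunction.forward` maps every chunk of every packed sequence to a `(sequence id, chunk id)` pair. The sketch below (helper name ours) reproduces that arithmetic with plain tensor ops, checks it on the `[0, 100, 356]` example from the comment in `forward`, and then shows what a variable-length call is expected to look like: all sequences packed along time into a batch of size 1, `head_first=False`, and `cu_seqlens` passed in FlashAttention style.

import torch
import torch.nn.functional as F

from fla.ops.simple_gla import parallel_simple_gla  # assumption: adjust to your checkout's export path


def build_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    # (sequence id, chunk id within that sequence) for every chunk, in packing order.
    lens = cu_seqlens[1:] - cu_seqlens[:-1]                 # per-sequence lengths
    n_chunks = (lens + chunk_size - 1) // chunk_size        # ceil division, as triton.cdiv does
    chunk_ids = torch.cat([torch.arange(n) for n in n_chunks.tolist()])
    seq_ids = chunk_ids.eq(0).cumsum(0) - 1                 # bumps by one at the start of each sequence
    return torch.stack([seq_ids, chunk_ids], 1).to(cu_seqlens)


cu_seqlens = torch.tensor([0, 100, 356])
print(build_chunk_indices(cu_seqlens, 64).tolist())
# -> [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]: 2 chunks for length 100, 4 for length 256

# Variable-length call sketch: two sequences (lengths 100 and 256) packed along the time axis.
# With cu_seqlens, the batch dimension must be 1 and head_first must be False.
H, K, V = 4, 64, 64
total_T = int(cu_seqlens[-1])
q = torch.randn(1, total_T, H, K, device='cuda', dtype=torch.bfloat16)
k = torch.randn(1, total_T, H, K, device='cuda', dtype=torch.bfloat16)
v = torch.randn(1, total_T, H, V, device='cuda', dtype=torch.bfloat16)
g = F.logsigmoid(torch.randn(1, total_T, H, device='cuda'))
o, _ = parallel_simple_gla(q, k, v, g, cu_seqlens=cu_seqlens.cuda(), head_first=False)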