# -*- coding: utf-8 -*- # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang import warnings from typing import Optional, Union import torch import triton import triton.language as tl from einops import rearrange from fla.ops.common.utils import prepare_chunk_indices, prepare_chunk_offsets, prepare_lens, prepare_token_indices from fla.ops.nsa.utils import _bitonic_merge from fla.ops.utils import mean_pooling from fla.ops.utils.op import exp, log from fla.utils import autocast_custom_bwd, autocast_custom_fwd, check_shared_mem, contiguous try: from flash_attn import flash_attn_func, flash_attn_varlen_func except ImportError: warnings.warn( "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`", category=ImportWarning ) flash_attn_func = None @triton.heuristics({ 'USE_OFFSETS': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4] ], key=['BS', 'BK', 'BV'], ) @triton.jit def parallel_nsa_compression_fwd_kernel( q, k, v, o, lse, scale, offsets, token_indices, chunk_offsets, T, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BC: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_OFFSETS: tl.constexpr, ): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos boc = tl.load(chunk_offsets + i_n).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T boc = i_b * tl.cdiv(T, BS) p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) # the number of compression representations in total TC = tl.cdiv(T, BS) # the number of compression representations required to iterate over # incomplete compression blocks are not included NC = (i_t + 1) // BS p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) # [G, BV] b_o = tl.zeros([G, BV], dtype=tl.float32) # max scores for the current block b_m = tl.full([G], float('-inf'), dtype=tl.float32) # lse = log(acc) + m b_acc = tl.zeros([G], dtype=tl.float32) for i_c in range(0, NC, BC): o_c = i_c + tl.arange(0, BC) p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c, i_v * BV), (BC, BV), (1, 0)) # [BK, BC] b_k = tl.load(p_k, boundary_check=(0, 1)) # [BC, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) # [G, BC] b_s = tl.dot(b_q, b_k) b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf')) # [G] b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m b_r = exp(b_mp - b_m) # [G, BC] b_p = exp(b_s - b_m[:, None]) # [G] b_acc = b_acc * b_r + tl.sum(b_p, 1) # [G, BV] b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) b_mp = b_m if NC == 0: b_lse = tl.zeros([G], dtype=tl.float32) else: b_o = b_o / b_acc[:, None] b_lse = b_m + log(b_acc) tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) if i_v == 0: tl.store(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G), b_lse.to(lse.dtype.element_ty)) @triton.heuristics({ 'USE_OFFSETS': lambda args: args['offsets'] is not None, }) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4] ], key=['BS', 'BK', 'BV'], ) @triton.jit(do_not_specialize=['T']) def parallel_nsa_compression_bwd_kernel_dq( q, k, v, lse, delta, do, dq, scale, offsets, token_indices, chunk_offsets, T, B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BC: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_OFFSETS: tl.constexpr ): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos boc = tl.load(chunk_offsets + i_n).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T boc = i_b * tl.cdiv(T, BS) q += (bos + i_t) * HQ*K do += (bos + i_t) * HQ*V lse += (bos + i_t) * HQ delta += (bos + i_t) * HQ dq += (i_v * B * T + bos + i_t) * HQ*K p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse = lse + i_h * G + tl.arange(0, G) p_delta = delta + i_h * G + tl.arange(0, G) # the number of compression representations in total TC = tl.cdiv(T, BS) # the number of compression representations required to iterate over # incomplete compression blocks are not included NC = (i_t + 1) // BS # [G, BV] b_do = tl.load(p_do, boundary_check=(0, 1)) # [G] b_lse = tl.load(p_lse) b_delta = tl.load(p_delta) # [G, BK] b_dq = tl.zeros([G, BK], dtype=tl.float32) for i_c in range(0, NC, BC): o_c = i_c + tl.arange(0, BC) p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (V, TC), (1, H*V), (i_v * BV, i_c), (BV, BC), (0, 1)) # [BK, BC] b_k = tl.load(p_k, boundary_check=(0, 1)) # [BV, BC] b_v = tl.load(p_v, boundary_check=(0, 1)) # [G, BC] b_s = tl.dot(b_q, b_k) b_p = exp(b_s - b_lse[:, None]) b_p = tl.where((o_c < NC)[None, :], b_p, 0) # [G, BV] @ [BV, BC] -> [G, BC] b_dp = tl.dot(b_do, b_v) b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) # [G, BC] @ [BC, BK] -> [G, BK] b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) b_dq *= scale tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) @triton.heuristics({ 'USE_OFFSETS': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4] ], key=['BS', 'BK', 'BV'], ) @triton.jit(do_not_specialize=['T']) def parallel_nsa_compression_bwd_kernel_dkv( q, k, v, lse, delta, do, dk, dv, offsets, chunk_indices, chunk_offsets, scale, T, B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BC: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_OFFSETS: tl.constexpr ): i_v, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: i_n, i_c = tl.load(chunk_indices + i_c * 2).to(tl.int32), tl.load(chunk_indices + i_c * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos boc = tl.load(chunk_offsets + i_n).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T boc = i_b * tl.cdiv(T, BS) # the number of compression representations in total TC = tl.cdiv(T, BS) p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0)) p_v = tl.make_block_ptr(v + (boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0)) p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + boc * H + i_h) * K, (TC, K), (H*K, 1), (i_c * BC, 0), (BC, BK), (1, 0)) p_dv = tl.make_block_ptr(dv + (i_v * B*T*H + boc * H + i_h) * V, (TC, V), (H*V, 1), (i_c * BC, i_v * BV), (BC, BV), (1, 0)) # [BC, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) b_dk = tl.zeros([BC, BK], dtype=tl.float32) # [BC, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) b_dv = tl.zeros([BC, BV], dtype=tl.float32) for i in range(i_c * BC * BS, T): o_c = i_c * BC + tl.arange(0, BC) p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_do = tl.load(p_do, boundary_check=(0, 1)) # [G] b_lse = tl.load(p_lse) b_delta = tl.load(p_delta) # [BC, G] b_s = tl.dot(b_k, tl.trans(b_q)) b_p = exp(b_s - b_lse[None, :]) b_p = tl.where((i >= max(0, (o_c + 1) * BS - 1))[:, None], b_p, 0) # [BC, G] @ [G, BV] -> [BC, BV] b_dv += tl.dot(b_p.to(b_do.dtype), b_do) # [BC, BV] @ [BV, G] -> [BC, G] b_dp = tl.dot(b_v, tl.trans(b_do)) # [BC, G] b_ds = b_p * (b_dp - b_delta[None, :]) # [BC, G] @ [G, BK] -> [BC, BK] b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) @triton.heuristics({ 'USE_OFFSETS': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4] ], key=['BS', 'BK'], ) @triton.jit def parallel_nsa_kernel_topk( q, k, lse, scale, block_indices, offsets, token_indices, chunk_offsets, T, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, S: tl.constexpr, BC: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, USE_OFFSETS: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos boc = tl.load(chunk_offsets + i_n).to(tl.int32) else: bos, eos = i_b * T, i_b * T + T boc = i_b * tl.cdiv(T, BS) p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) # the number of compression representations in total TC = tl.cdiv(T, BS) # the number of compression representations required to iterate over # incomplete compression blocks are not included NC = (i_t + 1) // BS ################################ # 1. lse computation ################################ if lse is not None: b_lse = tl.load(lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)) else: # max scores for the current block b_m = tl.full([G], float('-inf'), dtype=tl.float32) # lse = log(acc) + m b_acc = tl.zeros([G], dtype=tl.float32) for i_c in range(0, NC, BC): o_c = i_c + tl.arange(0, BC) p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) # [BK, BC] b_k = tl.load(p_k, boundary_check=(0, 1)) # [G, BC] b_s = tl.dot(b_q, b_k) b_s = tl.where((o_c < NC)[None, :], b_s, float('-inf')) # [G] b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m b_r = exp(b_mp - b_m) # [G, BC] b_p = exp(b_s - b_m[:, None]) # [G] b_acc = b_acc * b_r + tl.sum(b_p, 1) b_mp = b_m if NC == 0: b_lse = tl.zeros([G], dtype=tl.float32) else: b_lse = b_m + log(b_acc) ################################ # 2. topk selection ################################ # [BC] b_i = tl.full([BC], -1, dtype=tl.float32) o_i = tl.zeros([BC], dtype=tl.int32) m_i = tl.arange(0, BC) < BC//2 for i_c in range(0, i_t // BS + 1, BC): o_c = i_c + tl.arange(0, BC) p_k = tl.make_block_ptr(k + (boc * H + i_h) * K, (K, TC), (1, H*K), (0, i_c), (BK, BC), (0, 1)) # [BK, BC] b_k = tl.load(p_k, boundary_check=(0, 1)) # [G, BC] b_s = tl.dot(b_q, b_k) b_s = tl.where((i_t // BS > o_c)[None, :], b_s, float('-inf')) # [G, BC] b_p = tl.where((i_t // BS == o_c)[None, :], float(1.0), exp(b_s - b_lse[:, None])) # the importance scores of the current block # [BC] b_i, b_ip = tl.sum(b_p, 0), b_i o_i, o_ip = tl.where(o_c <= i_t // BS, o_c + 1, 0), o_i n_dims: tl.constexpr = tl.standard._log2(b_i.shape[0]) for i in tl.static_range(1, n_dims): b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), i, 2, n_dims) if i_c != 0: b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, False, n_dims) b_i_new = b_ip * m_i + b_i * (1 - m_i) o_i_new = o_ip * m_i + o_i * (1 - m_i) b_i, o_i = _bitonic_merge(b_i_new, o_i_new.to(tl.int32), n_dims, True, n_dims) else: b_i, o_i = _bitonic_merge(b_i, o_i.to(tl.int32), n_dims, True, n_dims) m_top = tl.arange(0, BC//S) == 0 b_top = tl.sum(m_top[:, None] * tl.reshape(o_i - 1, [BC//S, S]), 0) p_b = tl.make_block_ptr(block_indices + (bos + i_t) * H*S, (H*S,), (1,), (i_h * S,), (S,), (0,)) tl.store(p_b, b_top.to(p_b.dtype.element_ty)) @triton.heuristics({ 'USE_OFFSETS': lambda args: args['offsets'] is not None, 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor), }) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4] ], key=['BS', 'BK', 'BV'], ) @triton.jit def parallel_nsa_fwd_kernel( q, k, v, o, lse, scale, block_indices, block_counts, offsets, token_indices, T, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_OFFSETS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr ): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T k += (bos * H + i_h) * K v += (bos * H + i_h) * V block_indices += (bos + i_t) * H*S + i_h * S if USE_BLOCK_COUNTS: NS = tl.load(block_counts + (bos + i_t) * H + i_h) else: NS = S p_q = tl.make_block_ptr(q + (bos + i_t) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # the Q block is kept in the shared memory throughout the whole kernel # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) p_o = tl.make_block_ptr(o + (bos + i_t) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse = lse + (bos + i_t) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_o = tl.zeros([G, BV], dtype=tl.float32) b_m = tl.full([G], float('-inf'), dtype=tl.float32) b_acc = tl.zeros([G], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS if i_s <= i_t and i_s >= 0: p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_s, i_v * BV), (BS, BV), (1, 0)) # [BK, BS] b_k = tl.load(p_k, boundary_check=(0, 1)) # [BS, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) # [G, BS] b_s = tl.dot(b_q, b_k) b_s = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s, float('-inf')) # [G] b_m, b_mp = tl.maximum(b_m, tl.max(b_s, 1)), b_m b_r = exp(b_mp - b_m) # [G, BS] b_p = exp(b_s - b_m[:, None]) # [G] b_acc = b_acc * b_r + tl.sum(b_p, 1) # [G, BV] b_o = b_o * b_r[:, None] + tl.dot(b_p.to(b_q.dtype), b_v) b_mp = b_m b_o = b_o / b_acc[:, None] b_m += log(b_acc) tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_lse, b_m.to(p_lse.dtype.element_ty)) @triton.heuristics({ 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) }) @triton.jit def parallel_nsa_kernel_mask( block_indices, block_counts, block_mask, T: tl.constexpr, H: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, NS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr ): i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_h, i_s = i_hs // S, i_hs % S b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s) if USE_BLOCK_COUNTS: b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h) else: b_m = b_i * BS <= i_t if b_i < NS and b_i >= 0: tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i, b_m.to(block_mask.dtype.element_ty)) @triton.jit def parallel_nsa_bwd_kernel_preprocess( o, do, delta, B: tl.constexpr, V: tl.constexpr ): i_n = tl.program_id(0) o_d = tl.arange(0, B) m_d = o_d < V b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0) b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32) b_delta = tl.sum(b_o * b_do) tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty)) @triton.heuristics({ 'USE_OFFSETS': lambda args: args['offsets'] is not None, 'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor) }) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4] ], key=['BS', 'BK', 'BV'], ) @triton.jit(do_not_specialize=['T']) def parallel_nsa_bwd_kernel_dq( q, k, v, lse, delta, do, dq, scale, block_indices, block_counts, offsets, token_indices, T, B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_OFFSETS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr ): i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T q += (bos + i_t) * HQ*K do += (bos + i_t) * HQ*V lse += (bos + i_t) * HQ delta += (bos + i_t) * HQ dq += (i_v * B * T + bos + i_t) * HQ*K block_indices += (bos + i_t) * H*S + i_h * S if USE_BLOCK_COUNTS: NS = tl.load(block_counts + (bos + i_t) * H + i_h) else: NS = S k += (bos * H + i_h) * K v += (bos * H + i_h) * V p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) p_do = tl.make_block_ptr(do, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse = lse + i_h * G + tl.arange(0, G) p_delta = delta + i_h * G + tl.arange(0, G) # [G, BV] b_do = tl.load(p_do, boundary_check=(0, 1)) # [G] b_lse = tl.load(p_lse) b_delta = tl.load(p_delta) # [G, BK] b_dq = tl.zeros([G, BK], dtype=tl.float32) for i in range(NS): i_s = tl.load(block_indices + i).to(tl.int32) * BS if i_s <= i_t and i_s >= 0: p_k = tl.make_block_ptr(k, (K, T), (1, H*K), (0, i_s), (BK, BS), (0, 1)) p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_s), (BV, BS), (0, 1)) # [BK, BS] b_k = tl.load(p_k, boundary_check=(0, 1)) # [BV, BS] b_v = tl.load(p_v, boundary_check=(0, 1)) # [G, BS] b_s = tl.dot(b_q, b_k) b_p = exp(b_s - b_lse[:, None]) b_p = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p, 0) # [G, BV] @ [BV, BS] -> [G, BS] b_dp = tl.dot(b_do, b_v) b_ds = b_p * (b_dp.to(tl.float32) - b_delta[:, None]) # [G, BS] @ [BS, BK] -> [G, BK] b_dq += tl.dot(b_ds.to(b_k.dtype), tl.trans(b_k)) b_dq *= scale tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) @triton.heuristics({ 'USE_OFFSETS': lambda args: args['offsets'] is not None }) @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4] ], key=['BS', 'BK', 'BV'], ) @triton.jit(do_not_specialize=['T']) def parallel_nsa_bwd_kernel_dkv( q, k, v, lse, delta, do, dk, dv, block_mask, offsets, chunk_indices, scale, T, B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr, M: tl.constexpr, BS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_OFFSETS: tl.constexpr ): i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if USE_OFFSETS: i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 + 1).to(tl.int32) bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32) T = eos - bos else: bos, eos = i_b * T, i_b * T + T p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) p_dk = tl.make_block_ptr(dk + (i_v * B*T*H + bos * H + i_h) * K, (T, K), (H*K, 1), (i_s * BS, 0), (BS, BK), (1, 0)) p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H*V, 1), (i_s * BS, i_v * BV), (BS, BV), (1, 0)) # [BS, BK] b_k = tl.load(p_k, boundary_check=(0, 1)) b_dk = tl.zeros([BS, BK], dtype=tl.float32) # [BS, BV] b_v = tl.load(p_v, boundary_check=(0, 1)) b_dv = tl.zeros([BS, BV], dtype=tl.float32) for i in range(i_s * BS, T): b_m = tl.load(block_mask + (bos + i) * H*M + i_h * M + i_s) if b_m: p_q = tl.make_block_ptr(q + (bos + i) * HQ*K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0)) # [G, BK] b_q = tl.load(p_q, boundary_check=(0, 1)) b_q = (b_q * scale).to(b_q.dtype) p_do = tl.make_block_ptr(do + (bos + i) * HQ*V, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0)) p_lse = lse + (bos + i) * HQ + i_h * G + tl.arange(0, G) p_delta = delta + (bos + i) * HQ + i_h * G + tl.arange(0, G) # [G, BV] b_do = tl.load(p_do, boundary_check=(0, 1)) # [G] b_lse = tl.load(p_lse) b_delta = tl.load(p_delta) # [BS, G] b_s = tl.dot(b_k, tl.trans(b_q)) b_p = exp(b_s - b_lse[None, :]) b_p = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p, 0) # [BS, G] @ [G, BV] -> [BS, BV] b_dv += tl.dot(b_p.to(b_do.dtype), b_do) # [BS, BV] @ [BV, G] -> [BS, G] b_dp = tl.dot(b_v, tl.trans(b_do)) # [BS, G] b_ds = b_p * (b_dp - b_delta[None, :]) # [BS, G] @ [G, BK] -> [BS, BK] b_dk += tl.dot(b_ds.to(b_q.dtype), b_q) tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) def parallel_nsa_compression_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_size: int, scale: float, offsets: Optional[torch.LongTensor] = None, token_indices: Optional[torch.LongTensor] = None, ): B, T, HQ, K, V = *q.shape, v.shape[-1] H = k.shape[2] G = HQ // H BC = BS = block_size if check_shared_mem('hopper', q.device.index): BK = min(256, triton.next_power_of_2(K)) BV = min(256, triton.next_power_of_2(V)) else: BK = min(128, triton.next_power_of_2(K)) BV = min(128, triton.next_power_of_2(V)) NK = triton.cdiv(K, BK) NV = triton.cdiv(V, BV) assert NK == 1, "The key dimension can not be larger than 256" chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None grid = (T, NV, B * H) o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) parallel_nsa_compression_fwd_kernel[grid]( q=q, k=k, v=v, o=o, lse=lse, scale=scale, offsets=offsets, token_indices=token_indices, chunk_offsets=chunk_offsets, T=T, H=H, HQ=HQ, G=G, K=K, V=V, BC=BC, BS=BS, BK=BK, BV=BV, ) return o, lse def parallel_nsa_compression_bwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, lse: torch.Tensor, do: torch.Tensor, block_size: int = 64, scale: float = None, offsets: Optional[torch.LongTensor] = None, token_indices: Optional[torch.LongTensor] = None, ): B, T, HQ, K, V = *q.shape, v.shape[-1] H = k.shape[2] G = HQ // H BC = BS = block_size BK = triton.next_power_of_2(K) BV = min(128, triton.next_power_of_2(v.shape[-1])) NV = triton.cdiv(V, BV) if offsets is not None: lens = prepare_lens(offsets) chunk_indices = torch.cat([torch.arange(n) for n in triton.cdiv(triton.cdiv(lens, BS), BC).tolist()]) chunk_indices = torch.stack([chunk_indices.eq(0).cumsum(0) - 1, chunk_indices], 1).to(offsets) chunk_offsets = prepare_chunk_offsets(offsets, BS) NC = len(chunk_indices) else: chunk_indices, chunk_offsets = None, None NC = triton.cdiv(triton.cdiv(T, BS), BC) delta = parallel_nsa_bwd_preprocess(o, do) dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) grid = (T, NV, B * H) parallel_nsa_compression_bwd_kernel_dq[grid]( q=q, k=k, v=v, lse=lse, delta=delta, do=do, dq=dq, scale=scale, offsets=offsets, token_indices=token_indices, chunk_offsets=chunk_offsets, T=T, B=B, H=H, HQ=HQ, G=G, K=K, V=V, BC=BC, BS=BS, BK=BK, BV=BV ) dq = dq.sum(0) dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) grid = (NV, NC, B * H) parallel_nsa_compression_bwd_kernel_dkv[grid]( q=q, k=k, v=v, lse=lse, delta=delta, do=do, dk=dk, dv=dv, offsets=offsets, chunk_indices=chunk_indices, chunk_offsets=chunk_offsets, scale=scale, T=T, B=B, H=H, HQ=HQ, G=G, K=K, V=V, BC=BC, BS=BS, BK=BK, BV=BV ) dk = dk.sum(0) return dq, dk, dv class ParallelNSACompressionFunction(torch.autograd.Function): @staticmethod @contiguous @autocast_custom_fwd def forward( ctx, q, k, v, block_size, scale, offsets ): ctx.dtype = q.dtype # 2-d sequence indices denoting the offsets of tokens in each sequence # for example, if the passed `offsets` is [0, 2, 6], # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None o, lse = parallel_nsa_compression_fwd( q=q, k=k, v=v, block_size=block_size, scale=scale, offsets=offsets, token_indices=token_indices ) ctx.save_for_backward(q, k, v, o, lse) ctx.offsets = offsets ctx.token_indices = token_indices ctx.block_size = block_size ctx.scale = scale return o.to(q.dtype), lse @staticmethod @contiguous @autocast_custom_bwd def backward(ctx, do, *args): q, k, v, o, lse = ctx.saved_tensors dq, dk, dv = parallel_nsa_compression_bwd( q=q, k=k, v=v, o=o, lse=lse, do=do, block_size=ctx.block_size, scale=ctx.scale, offsets=ctx.offsets, token_indices=ctx.token_indices ) return dq.to(q), dk.to(k), dv.to(v), None, None, None def parallel_nsa_topk( q: torch.Tensor, k: torch.Tensor, lse: torch.Tensor, block_counts: Union[torch.LongTensor, int], block_size: int = 64, scale: float = None, offsets: Optional[torch.LongTensor] = None, ) -> torch.LongTensor: B, T, HQ, K = q.shape H = k.shape[2] G = HQ // H S = block_counts if isinstance(block_counts, int) else block_counts.max().item() S = triton.next_power_of_2(S) # here we set BC = BS, but beware that they are actually decoupled BC = BS = block_size BK = triton.next_power_of_2(K) block_indices = torch.zeros(B, T, H, S, dtype=torch.int32, device=q.device) token_indices = prepare_token_indices(offsets) if offsets is not None else None chunk_offsets = prepare_chunk_offsets(offsets, BS) if offsets is not None else None grid = (T, B * H) parallel_nsa_kernel_topk[grid]( q=q, k=k, lse=lse, scale=scale, block_indices=block_indices, offsets=offsets, token_indices=token_indices, chunk_offsets=chunk_offsets, T=T, H=H, HQ=HQ, G=G, K=K, S=S, BC=BC, BS=BS, BK=BK ) return block_indices def parallel_nsa_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_indices: torch.LongTensor, block_counts: Union[torch.LongTensor, int], block_size: int, scale: float, offsets: Optional[torch.LongTensor] = None, token_indices: Optional[torch.LongTensor] = None, ): B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] HQ = q.shape[2] G = HQ // H BS = block_size if check_shared_mem('hopper', q.device.index): BK = min(256, triton.next_power_of_2(K)) BV = min(256, triton.next_power_of_2(V)) else: BK = min(128, triton.next_power_of_2(K)) BV = min(128, triton.next_power_of_2(V)) NK = triton.cdiv(K, BK) NV = triton.cdiv(V, BV) assert NK == 1, "The key dimension can not be larger than 256" grid = (T, NV, B * H) o = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) lse = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) parallel_nsa_fwd_kernel[grid]( q=q, k=k, v=v, o=o, lse=lse, scale=scale, block_indices=block_indices, block_counts=block_counts, offsets=offsets, token_indices=token_indices, T=T, H=H, HQ=HQ, G=G, K=K, V=V, S=S, BS=BS, BK=BK, BV=BV, ) return o, lse def parallel_nsa_block_mask( block_indices: torch.LongTensor, block_counts: Union[torch.LongTensor, int], offsets: torch.LongTensor, block_size: int, ): B, T, H, S = block_indices.shape BS = block_size if offsets is not None: NS = triton.cdiv(prepare_lens(offsets).max().item(), BS) else: NS = triton.cdiv(T, BS) block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device) parallel_nsa_kernel_mask[(T, B, H*S)]( block_indices=block_indices, block_counts=block_counts, block_mask=block_mask, T=T, H=H, S=S, BS=BS, NS=NS ) return block_mask def parallel_nsa_bwd_preprocess( o: torch.Tensor, do: torch.Tensor ): V = o.shape[-1] delta = torch.empty_like(o[..., 0], dtype=torch.float32) parallel_nsa_bwd_kernel_preprocess[(delta.numel(),)]( o=o, do=do, delta=delta, B=triton.next_power_of_2(V), V=V, ) return delta def parallel_nsa_bwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, o: torch.Tensor, lse: torch.Tensor, do: torch.Tensor, block_indices: torch.Tensor, block_counts: Union[torch.LongTensor, int], block_size: int = 64, scale: float = None, offsets: Optional[torch.LongTensor] = None, token_indices: Optional[torch.LongTensor] = None, ): B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1] HQ = q.shape[2] G = HQ // H BS = block_size BK = triton.next_power_of_2(K) BV = min(128, triton.next_power_of_2(v.shape[-1])) NV = triton.cdiv(V, BV) delta = parallel_nsa_bwd_preprocess(o, do) dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device) grid = (T, NV, B * H) parallel_nsa_bwd_kernel_dq[grid]( q=q, k=k, v=v, lse=lse, delta=delta, do=do, dq=dq, block_indices=block_indices, block_counts=block_counts, offsets=offsets, token_indices=token_indices, scale=scale, T=T, B=B, H=H, HQ=HQ, G=G, K=K, V=V, S=S, BS=BS, BK=BK, BV=BV ) dq = dq.sum(0) if offsets is not None: chunk_indices = prepare_chunk_indices(offsets, BS) NS = len(chunk_indices) else: chunk_indices = None NS = triton.cdiv(T, BS) # [B, T, H, M] block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size) dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device) dv = torch.empty(v.shape, dtype=v.dtype, device=q.device) grid = (NV, NS, B * H) parallel_nsa_bwd_kernel_dkv[grid]( q=q, k=k, v=v, lse=lse, delta=delta, do=do, dk=dk, dv=dv, block_mask=block_mask, offsets=offsets, chunk_indices=chunk_indices, scale=scale, T=T, B=B, H=H, HQ=HQ, G=G, K=K, V=V, M=block_mask.shape[-1], BS=BS, BK=BK, BV=BV ) dk = dk.sum(0) return dq, dk, dv @torch.compile class ParallelNSAFunction(torch.autograd.Function): @staticmethod @contiguous @autocast_custom_fwd def forward(ctx, q, k, v, block_indices, block_counts, block_size, scale, offsets): ctx.dtype = q.dtype # 2-d sequence indices denoting the offsets of tokens in each sequence # for example, if the passed `offsets` is [0, 2, 6], # then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]] token_indices = prepare_token_indices(offsets) if offsets is not None else None o, lse = parallel_nsa_fwd( q=q, k=k, v=v, block_indices=block_indices, block_counts=block_counts, block_size=block_size, scale=scale, offsets=offsets, token_indices=token_indices ) ctx.save_for_backward(q, k, v, o, lse) ctx.block_indices = block_indices ctx.block_counts = block_counts ctx.offsets = offsets ctx.token_indices = token_indices ctx.block_size = block_size ctx.scale = scale return o.to(q.dtype) @staticmethod @contiguous @autocast_custom_bwd def backward(ctx, do): q, k, v, o, lse = ctx.saved_tensors dq, dk, dv = parallel_nsa_bwd( q=q, k=k, v=v, o=o, lse=lse, do=do, block_indices=ctx.block_indices, block_counts=ctx.block_counts, block_size=ctx.block_size, scale=ctx.scale, offsets=ctx.offsets, token_indices=ctx.token_indices ) return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None def parallel_nsa_compression( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, block_size: int = 64, scale: float = None, offsets: Optional[torch.LongTensor] = None ): return ParallelNSACompressionFunction.apply( q, k, v, block_size, scale, offsets ) def parallel_nsa( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g_cmp: torch.Tensor, g_slc: torch.Tensor, g_swa: torch.Tensor, block_indices: Optional[torch.LongTensor] = None, block_counts: Union[torch.LongTensor, int] = 16, block_size: int = 64, window_size: int = 0, scale: Optional[float] = None, cu_seqlens: Optional[torch.LongTensor] = None, head_first: bool = False ) -> torch.Tensor: r""" Args: q (torch.Tensor): queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`. k (torch.Tensor): keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16. v (torch.Tensor): values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. g_cmp (torch.Tensor): Gate score for compressed attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. g_slc (torch.Tensor): Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. g_swa (torch.Tensor): Gate score for sliding attentionof shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`. block_indices (torch.LongTensor): Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`. `S` is the number of selected blocks for each query token, which is set to 16 in the paper. If `g_cmp` is provided, the passed `block_indices` will be ignored. block_counts (Optional[Union[torch.LongTensor, int]]): Number of selected blocks for each query. If a tensor is provided, with shape `[B, T, H]` if `head_first=False` else `[B, H, T]`, each query can select the same number of blocks. If not provided, it will default to 16. block_size (int): Selected block size. Default: 64. window_size (int): Sliding window size. Default: 0. scale (Optional[int]): Scale factor for attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. head_first (Optional[bool]): Whether the inputs are in the head-first format. Default: `False`. cu_seqlens (torch.LongTensor): Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. Returns: o (torch.Tensor): Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`. """ assert block_counts is not None, "block counts must be provided for selection" if scale is None: scale = k.shape[-1] ** -0.5 if cu_seqlens is not None: assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided" if head_first: q, k, v = map(lambda x: rearrange(x, 'b h t d -> b t h d'), (q, k, v)) g_cmp, g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h') if x is not None else None, (g_cmp, g_slc, g_swa)) if not isinstance(block_counts, int): block_counts = rearrange(block_counts, 'b h t -> b t h') assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA" k_cmp, v_cmp = mean_pooling(k, block_size, cu_seqlens), mean_pooling(v, block_size, cu_seqlens) o_cmp, lse_cmp = None, None if g_cmp is not None: o_cmp, lse_cmp = parallel_nsa_compression( q=q, k=k_cmp, v=v_cmp, block_size=block_size, scale=scale, offsets=cu_seqlens ) if block_indices is not None: warnings.warn("`block_indices` will be ignored when `g_cmp` is provided") block_indices = parallel_nsa_topk( q=q, k=k_cmp, lse=lse_cmp, block_counts=block_counts, block_size=block_size, scale=scale, offsets=cu_seqlens ) o_slc = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size, scale, cu_seqlens) o = o_slc * g_slc.unsqueeze(-1) if o_cmp is not None: o = torch.addcmul(o, o_cmp, g_cmp.unsqueeze(-1)) if window_size > 0: if cu_seqlens is not None: max_seqlen = q.shape[1] o_swa = flash_attn_varlen_func( q.squeeze(0), k.squeeze(0), v.squeeze(0), cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, causal=True, window_size=(window_size-1, 0) ).unsqueeze(0) else: o_swa = flash_attn_func( q, k, v, causal=True, window_size=(window_size-1, 0) ) o = torch.addcmul(o, o_swa, g_swa.unsqueeze(-1)) if head_first: o = rearrange(o, 'b t h d -> b h t d') return o