import math import torch import triton import triton.language as tl @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _bwd_eva_agg_kernel_dkdv( Q, K, V, WindowMask, DO, LSE, DO_T_O, DK, DV, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_window_mask_b, stride_window_mask_m, stride_do_b, stride_do_h, stride_do_m, stride_lse_b, stride_lse_h, stride_do_t_o_b, stride_do_t_o_h, stride_dk_b, stride_dk_h, stride_dk_n, stride_dv_b, stride_dv_h, stride_dv_n, nheads, seqlen_q, seqlen_k, headdim, WINDOW_SIZE: tl.constexpr, MASK_TYPE: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_W: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): off_bh = tl.program_id(1) off_h = off_bh % nheads off_b = off_bh // nheads start_n = tl.program_id(0) # determine which window the current KV block belongs to offs_w = (start_n * BLOCK_N) // WINDOW_SIZE offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) # initialize pointers q_ptrs = ( Q + off_b * stride_qb + off_h * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] ) k_ptrs = ( K + off_b * stride_kb + off_h * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] ) v_ptrs = ( V + off_b * stride_vb + off_h * stride_vh + offs_n[:, None] * stride_vn + offs_d[None, :] ) do_ptrs = ( DO + off_b * stride_do_b + off_h * stride_do_h + offs_m[:, None] * stride_do_m + offs_d[None, :] ) do_t_o_ptrs = ( DO_T_O + off_b * stride_do_t_o_b + off_h * stride_do_t_o_h + offs_m[:, None] ) lse_ptrs = ( LSE + off_b * stride_lse_b + off_h * stride_lse_h + offs_m[:, None] ) if MASK_TYPE == 1: m_ptrs = ( WindowMask + off_b * stride_window_mask_b + (offs_m[:, None] * stride_window_mask_m + offs_n[None, :]) ) dk_ptrs = ( DK + off_b * stride_dk_b + off_h * stride_dk_h + offs_n[:, None] * stride_dk_n + offs_d[None, :] ) dv_ptrs = ( DV + off_b * stride_dv_b + off_h * stride_dv_h + offs_n[:, None] * stride_dv_n + offs_d[None, :] ) # 1. for singletons # determine start and end of query block begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M end_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load(k_ptrs) v = tl.load(v_ptrs) else: k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) else: if EVEN_HEADDIM: k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) else: k = tl.load( k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 ) v = tl.load( v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 ) for start_m in range(begin_m, end_m, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) # load q, do, and lse if EVEN_M & EVEN_N: if EVEN_HEADDIM: q = tl.load( q_ptrs + start_m * stride_qm ) do = tl.load( do_ptrs + start_m * stride_do_m ) else: q = tl.load( q_ptrs + start_m * stride_qm, mask=offs_d[None, :] < headdim, other=0.0 ) do = tl.load( do_ptrs + start_m * stride_do_m, mask=offs_d[None, :] < headdim, other=0.0 ) do_t_o = tl.load( do_t_o_ptrs + start_m ) lse = tl.load( lse_ptrs + start_m ) else: if EVEN_HEADDIM: q = tl.load( q_ptrs + start_m * stride_qm, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0 ) do = tl.load( do_ptrs + start_m * stride_do_m, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0 ) else: q = tl.load( q_ptrs + start_m * stride_qm, mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) do = tl.load( do_ptrs + start_m * stride_do_m, mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) do_t_o = tl.load( do_t_o_ptrs + start_m, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0 ) lse = tl.load( lse_ptrs + start_m, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0 ) lse = tl.where(lse == float("-inf"), 0.0, lse) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(k)) if not EVEN_M: qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf")) if MASK_TYPE == 1: if EVEN_M & EVEN_W: mask = tl.load( m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE) ) else: mask = tl.load( m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE), mask=((start_m + offs_m)[:, None] < seqlen_q) & (((start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE) + offs_n)[None, :] < WINDOW_SIZE), other=1, ) # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. # we assume mask already implies the causal masking qk = qk * softmax_scale qk = tl.where(mask, float("-inf"), qk) p = tl.exp(qk - lse) else: qk += tl.where((start_m + offs_m)[:, None] >= offs_n[None, :], 0, float("-inf")) p = tl.exp(qk * softmax_scale - lse) # dp [M, N] dp = tl.dot(do, tl.trans(v)) # p [M, N], dp [M, N], do_t_o [M, 1] -> ds [M, N] ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype) # p is fp32 and [M, N], convert to q.dtype # do [M, D] -> dv [N, D] dv += tl.dot(tl.trans(p.to(do.dtype)), do) # dk [N, D] dk += tl.dot(tl.trans(ds), q) if EVEN_N & EVEN_M: if EVEN_HEADDIM: tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) else: tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) else: if EVEN_HEADDIM: tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) else: tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _bwd_eva_agg_kernel_drfa_kv( Q, RFA_K, RFA_V, ChunkMask, DO, LSE, DO_T_O, D_RFA_K, D_RFA_V, softmax_scale, stride_qb, stride_qh, stride_qm, stride_rfa_kb, stride_rfa_kh, stride_rfa_kc, stride_rfa_vb, stride_rfa_vh, stride_rfa_vc, stride_chunk_mask_b, stride_chunk_mask_m, stride_do_b, stride_do_h, stride_do_m, stride_lse_b, stride_lse_h, stride_do_t_o_b, stride_do_t_o_h, stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c, stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c, nheads, seqlen_q, nchunks, headdim, CHUNKS_PER_WINDOW: tl.constexpr, WINDOW_SIZE: tl.constexpr, MASK_TYPE: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_C: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): off_bh = tl.program_id(1) off_h = off_bh % nheads off_b = off_bh // nheads start_c = tl.program_id(0) # there are 128 chunks per window offs_c = start_c * BLOCK_N + tl.arange(0, BLOCK_N) # determine which window the current KV block belongs to offs_w = (start_c * BLOCK_N) // CHUNKS_PER_WINDOW offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) # initialize pointers q_ptrs = ( Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) ) do_ptrs = ( DO + off_b * stride_do_b + off_h * stride_do_h + (offs_m[:, None] * stride_do_m + offs_d[None, :]) ) do_t_o_ptrs = ( DO_T_O + off_b * stride_do_t_o_b + off_h * stride_do_t_o_h + (offs_m[:, None]) ) lse_ptrs = ( LSE + off_b * stride_lse_b + off_h * stride_lse_h + (offs_m[:, None]) ) rfa_k_ptrs = ( RFA_K + off_b * stride_rfa_kb + off_h * stride_rfa_kh + (offs_c[:, None] * stride_rfa_kc + offs_d[None, :]) ) rfa_v_ptrs = ( RFA_V + off_b * stride_rfa_vb + off_h * stride_rfa_vh + (offs_c[:, None] * stride_rfa_vc + offs_d[None, :]) ) if MASK_TYPE == 1: rfa_m_ptrs = ( ChunkMask + off_b * stride_chunk_mask_b + (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :]) ) d_rfa_k_ptrs = ( D_RFA_K + off_b * stride_d_rfa_k_b + off_h * stride_d_rfa_k_h + (offs_c[:, None] * stride_d_rfa_k_c + offs_d[None, :]) ) d_rfa_v_ptrs = ( D_RFA_V + off_b * stride_d_rfa_v_b + off_h * stride_d_rfa_v_h + (offs_c[:, None] * stride_d_rfa_v_c + offs_d[None, :]) ) d_rfa_k = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) d_rfa_v = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) if EVEN_C & EVEN_M: if EVEN_HEADDIM: rfa_k = tl.load(rfa_k_ptrs) rfa_v = tl.load(rfa_v_ptrs) else: rfa_k = tl.load(rfa_k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) rfa_v = tl.load(rfa_v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) else: if EVEN_HEADDIM: rfa_k = tl.load(rfa_k_ptrs, mask=offs_c[:, None] < nchunks, other=0.0) rfa_v = tl.load(rfa_v_ptrs, mask=offs_c[:, None] < nchunks, other=0.0) else: rfa_k = tl.load( rfa_k_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0 ) rfa_v = tl.load( rfa_v_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0 ) begin_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q) end_m = seqlen_q for start_m in range(begin_m, end_m, BLOCK_M): start_m = tl.multiple_of(start_m, BLOCK_M) # load q, do, and lse if EVEN_M: if EVEN_HEADDIM: q = tl.load( q_ptrs + start_m * stride_qm ) do = tl.load( do_ptrs + start_m * stride_do_m ) else: q = tl.load( q_ptrs + start_m * stride_qm, mask=offs_d[None, :] < headdim, other=0.0 ) do = tl.load( do_ptrs + start_m * stride_do_m, mask=offs_d[None, :] < headdim, other=0.0 ) do_t_o = tl.load( do_t_o_ptrs + start_m ) lse = tl.load( lse_ptrs + start_m ) else: if EVEN_HEADDIM: q = tl.load( q_ptrs + start_m * stride_qm, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0 ) do = tl.load( do_ptrs + start_m * stride_do_m, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0 ) else: q = tl.load( q_ptrs + start_m * stride_qm, mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) do = tl.load( do_ptrs + start_m * stride_do_m, mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) do_t_o = tl.load( do_t_o_ptrs + start_m, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0 ) lse = tl.load( lse_ptrs + start_m, mask=(start_m + offs_m)[:, None] < seqlen_q, other=0.0 ) lse = tl.where(lse == float("-inf"), 0.0, lse) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(rfa_k)) if not EVEN_M: qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf")) if MASK_TYPE == 1: if EVEN_M & EVEN_C: mask = tl.load( rfa_m_ptrs + (start_m * stride_chunk_mask_m) ) else: mask = tl.load( rfa_m_ptrs + (start_m * stride_chunk_mask_m), mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_c[None, :] < nchunks), other=1, ) # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. # we assume mask already implies the causal masking qk = qk * softmax_scale qk = tl.where(mask, float("-inf"), qk) p = tl.exp(qk - lse) else: p = tl.exp(qk * softmax_scale - lse) dp = tl.dot(do, tl.trans(rfa_v)) ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype) # p is fp32, convert to q.dtype d_rfa_v += tl.dot(tl.trans(p.to(do.dtype)), do) # move softmax_scale to ds to save computation d_rfa_k += tl.dot(tl.trans(ds), q) if EVEN_C & EVEN_M: if EVEN_HEADDIM: tl.store(d_rfa_v_ptrs, d_rfa_v) tl.store(d_rfa_k_ptrs, d_rfa_k) else: tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_d[None, :] < headdim) tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_d[None, :] < headdim) else: if EVEN_HEADDIM: tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_c[:, None] < nchunks) tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_c[:, None] < nchunks) else: tl.store(d_rfa_v_ptrs, d_rfa_v, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim)) tl.store(d_rfa_k_ptrs, d_rfa_k, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim)) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0, "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _bwd_eva_agg_kernel_dq( Q, K, V, RFA_K, RFA_V, WindowMask, ChunkMask, DO, LSE, DO_T_O, DQ, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_rfa_kb, stride_rfa_kh, stride_rfa_kc, stride_rfa_vb, stride_rfa_vh, stride_rfa_vc, stride_window_mask_b, stride_window_mask_m, stride_chunk_mask_b, stride_chunk_mask_m, stride_do_b, stride_do_h, stride_do_m, stride_lse_b, stride_lse_h, stride_do_t_o_b, stride_do_t_o_h, stride_dq_b, stride_dq_h, stride_dq_m, nheads, seqlen_q, seqlen_k, nchunks, headdim, CHUNKS_PER_WINDOW: tl.constexpr, WINDOW_SIZE: tl.constexpr, MASK_TYPE: tl.constexpr, EMPTY_RFA_KV: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_W: tl.constexpr, EVEN_C: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): start_m = tl.program_id(0) off_bh = tl.program_id(1) off_h = off_bh % nheads off_b = off_bh // nheads # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_w = (start_m * BLOCK_M) // WINDOW_SIZE offs_n = tl.arange(0, BLOCK_N) offs_c = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) # TODO: add paratheses or not q_ptrs = ( Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) ) k_ptrs = ( K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) ) v_ptrs = ( V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) ) if EMPTY_RFA_KV == 0: rfa_k_ptrs = ( RFA_K + off_b * stride_rfa_kb + off_h * stride_rfa_kh + (offs_c[:, None] * stride_rfa_kc + offs_d[None, :]) ) rfa_v_ptrs = ( RFA_V + off_b * stride_rfa_vb + off_h * stride_rfa_vh + (offs_c[:, None] * stride_rfa_vc + offs_d[None, :]) ) dq_ptrs = ( DQ + off_b * stride_dq_b + off_h * stride_dq_h + (offs_m[:, None] * stride_dq_m + offs_d[None, :]) ) do_ptrs = ( DO + off_b * stride_do_b + off_h * stride_do_h + (offs_m[:, None] * stride_do_m + offs_d[None, :]) ) do_t_o_ptrs = ( DO_T_O + off_b * stride_do_t_o_b + off_h * stride_do_t_o_h + offs_m[:, None] ) lse_ptrs = ( LSE + off_b * stride_lse_b + off_h * stride_lse_h + offs_m[:, None] ) ### load q, do, do_t_o, lse #### if EVEN_M: if EVEN_HEADDIM: q = tl.load( q_ptrs ) do = tl.load( do_ptrs ) else: q = tl.load( q_ptrs, mask=offs_d[None, :] < headdim, other=0.0 ) do = tl.load( do_ptrs, mask=offs_d[None, :] < headdim, other=0.0 ) do_t_o = tl.load( do_t_o_ptrs ) lse = tl.load( lse_ptrs ) else: if EVEN_HEADDIM: q = tl.load( q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0 ) do = tl.load( do_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0 ) else: q = tl.load( q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) do = tl.load( do_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) do_t_o = tl.load( do_t_o_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0 ) lse = tl.load( lse_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0 ) lse = tl.where(lse == float("-inf"), 0.0, lse) lse *= 1.4426950408889634 # log2(e) qk_scale = softmax_scale qk_scale *= 1.4426950408889634 # log2(e) if MASK_TYPE == 1: window_mask_ptrs = ( WindowMask + off_b * stride_window_mask_b + (offs_m[:, None] * stride_window_mask_m + offs_n[None, :]) ) if EMPTY_RFA_KV == 0: chunk_mask_ptrs = ( ChunkMask + off_b * stride_chunk_mask_b + (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :]) ) dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # loop over k, v and update accumulator # Iterate over local singletons; # so we only iterate over blocks within the current window start_idx_n = offs_w * WINDOW_SIZE end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(start_idx_n, end_idx_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load( k_ptrs + start_n * stride_kn ) else: k = tl.load( k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: k = tl.load( k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(k)) # Trying to combine the two masks seem to make the result wrong if not EVEN_N: # Need to mask out otherwise the softmax is wrong qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) if MASK_TYPE == 1: if EVEN_M & EVEN_W: window_mask = tl.load( window_mask_ptrs + start_n - start_idx_n ) else: window_mask = tl.load( window_mask_ptrs + start_n - start_idx_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE), other=1, ) # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. # we assume mask already implies the causal masking qk = qk * qk_scale qk = tl.where(window_mask, float("-inf"), qk) p = tl.exp2(qk - lse) else: qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) p = tl.exp2(qk * qk_scale - lse) if EVEN_N & EVEN_M: if EVEN_HEADDIM: v = tl.load( v_ptrs + start_n * stride_vn ) else: v = tl.load( v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: v = tl.load( v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: v = tl.load( v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) dp = tl.dot(do, tl.trans(v)) ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype) dq += tl.dot(ds, k) if EMPTY_RFA_KV == 0: # Iterate over RFA chunks # we only iterate over chunks before the current local singleton window end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks) for start_c in range(0, end_idx_c, BLOCK_N): start_c = tl.multiple_of(start_c, BLOCK_N) # -- compute qk ---- if EVEN_C & EVEN_M: if EVEN_HEADDIM: rfa_k = tl.load( rfa_k_ptrs + start_c * stride_rfa_kc ) else: rfa_k = tl.load( rfa_k_ptrs + start_c * stride_rfa_kc, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: rfa_k = tl.load( rfa_k_ptrs + start_c * stride_rfa_kc, mask=(start_c + offs_c)[:, None] < nchunks, other=0.0, ) else: rfa_k = tl.load( rfa_k_ptrs + start_c * stride_rfa_kc, mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0, ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(rfa_k)) # Trying to combine the two masks seem to make the result wrong if not EVEN_C: # Need to mask out otherwise the softmax is wrong qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf")) if MASK_TYPE == 1: if EVEN_C & EVEN_M: chunk_mask = tl.load( chunk_mask_ptrs + start_c ) else: chunk_mask = tl.load( chunk_mask_ptrs + start_c, mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks), other=1, ) # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. # we assume mask already implies the causal masking qk = qk * qk_scale qk = tl.where(chunk_mask, float("-inf"), qk) p = tl.exp2(qk - lse) else: p = tl.exp2(qk * qk_scale - lse) if EVEN_C & EVEN_M: if EVEN_HEADDIM: rfa_v = tl.load( rfa_v_ptrs + start_c * stride_rfa_vc ) else: rfa_v = tl.load( rfa_v_ptrs + start_c * stride_rfa_vc, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: rfa_v = tl.load( rfa_v_ptrs + start_c * stride_rfa_vc, mask=(start_c + offs_n)[:, None] < nchunks, other=0.0, ) else: rfa_v = tl.load( rfa_v_ptrs + start_c * stride_rfa_vc, mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0, ) dp = tl.dot(do, tl.trans(rfa_v)) ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype) dq += tl.dot(ds, rfa_k) start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) dq_ptrs = ( DQ + off_b * stride_dq_b + off_h * stride_dq_h + (offs_m[:, None] * stride_dq_m + offs_d[None, :]) ) if EVEN_M: if EVEN_HEADDIM: tl.store( dq_ptrs, dq ) else: tl.store( dq_ptrs, dq, mask=offs_d[None, :] < headdim ) else: if EVEN_HEADDIM: tl.store( dq_ptrs, dq, mask=offs_m[:, None] < seqlen_q ) else: tl.store( dq_ptrs, dq, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) ) _capability_90_config = { "fwd": { (torch.bfloat16, 64): (128, 128, 4, 3), (torch.bfloat16, 128): (128, 128, 8, 3), (torch.float32, 64): (128, 64, 8, 3), (torch.float32, 128): (64, 32, 4, 3), }, "bwd_dq": { (torch.bfloat16, 64): (128, 64, 4, 3), (torch.bfloat16, 128): (128, 64, 8, 3), (torch.float32, 64): (128, 64, 8, 2), (torch.float32, 128): (32, 32, 4, 2), }, "bwd_dkdv": { (torch.bfloat16, 64): (128, 64, 4, 2), (torch.bfloat16, 128): (128, 64, 8, 2), (torch.float32, 64): (128, 64, 8, 2), (torch.float32, 128): (32, 32, 4, 1), }, "bwd_drfa_kv": { (torch.bfloat16, 64): (128, 64, 4, 2), (torch.bfloat16, 128): (128, 64, 8, 2), (torch.float32, 64): (128, 64, 8, 2), (torch.float32, 128): (32, 32, 4, 1), } } _capability_80_config = { "fwd": { (torch.bfloat16, 64): (64, 64, 4, 3), (torch.bfloat16, 128): (64, 64, 8, 3), (torch.float32, 64): (64, 32, 4, 2), (torch.float32, 128): (64, 32, 8, 1), }, "bwd_dq": { (torch.bfloat16, 64): (64, 64, 4, 3), (torch.bfloat16, 128): (64, 32, 4, 2), (torch.float32, 64): (32, 32, 4, 2), (torch.float32, 128): (32, 32, 4, 2), }, "bwd_dkdv": { (torch.bfloat16, 64): (64, 64, 4, 3), (torch.bfloat16, 128): (32, 32, 4, 2), (torch.float32, 64): (32, 32, 4, 1), (torch.float32, 128): (16, 64, 8, 1), }, "bwd_drfa_kv": { (torch.bfloat16, 64): (64, 64, 4, 3), (torch.bfloat16, 128): (64, 32, 4, 3), (torch.float32, 64): (32, 32, 4, 1), (torch.float32, 128): (32, 32, 4, 1), } } def _get_config(dtype, head_dim, mode) -> tuple[int, int, int, int]: capability = torch.cuda.get_device_capability() if capability >= (9, 0): kernel_config = _capability_90_config[mode].get((dtype, head_dim), (32, 32, 4, 1)) elif capability >= (8, 0): kernel_config = _capability_80_config[mode].get((dtype, head_dim), (16, 16, 4, 1)) else: if mode == "fwd": if dtype == torch.float32: kernel_config = (32, 16, 4, 2) else: kernel_config = (64, 32, 4, 2) else: if dtype == torch.float32: kernel_config = (16, 16, 4, 1) else: kernel_config = (32, 32, 4, 1) return kernel_config @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0, "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _fwd_eva_agg_kernel( Q, K, V, RFA_K, RFA_V, WindowMask, ChunkMask, Out, LSE, softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_rfa_kb, stride_rfa_kh, stride_rfa_kc, stride_rfa_vb, stride_rfa_vh, stride_rfa_vc, stride_window_mask_b, stride_window_mask_m, stride_chunk_mask_b, stride_chunk_mask_m, stride_ob, stride_oh, stride_om, stride_lse_b, stride_lse_h, nheads, seqlen_q, seqlen_k, nchunks, headdim, CHUNKS_PER_WINDOW: tl.constexpr, WINDOW_SIZE: tl.constexpr, MASK_TYPE: tl.constexpr, EMPTY_RFA_KV: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_W: tl.constexpr, EVEN_C: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): start_m = tl.program_id(0) off_bh = tl.program_id(1) off_h = off_bh % nheads off_b = off_bh // nheads # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_w = (start_m * BLOCK_M) // WINDOW_SIZE offs_n = tl.arange(0, BLOCK_N) offs_c = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) # TODO: add paratheses or not q_ptrs = ( Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) ) k_ptrs = ( K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) ) v_ptrs = ( V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) ) if EMPTY_RFA_KV == 0: rfa_k_ptrs = ( RFA_K + off_b * stride_rfa_kb + off_h * stride_rfa_kh + (offs_c[:, None] * stride_rfa_kc + offs_d[None, :]) ) rfa_v_ptrs = ( RFA_V + off_b * stride_rfa_vb + off_h * stride_rfa_vh + (offs_c[:, None] * stride_rfa_vc + offs_d[None, :]) ) qk_scale = softmax_scale qk_scale *= 1.4426950408889634 # log2(e) if MASK_TYPE == 1: window_mask_ptrs = ( WindowMask + off_b * stride_window_mask_b + (offs_m[:, None] * stride_window_mask_m + offs_n[None, :]) ) if EMPTY_RFA_KV == 0: chunk_mask_ptrs = ( ChunkMask + off_b * stride_chunk_mask_b + (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :]) ) m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") d_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # load q: it will stay in SRAM throughout # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call # tl.load(q_ptrs), we get the wrong output! if EVEN_M & EVEN_N: if EVEN_HEADDIM: q = tl.load( q_ptrs ) else: q = tl.load( q_ptrs, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: q = tl.load( q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0 ) else: q = tl.load( q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 ) # loop over k, v and update accumulator # Iterate over local singletons; # so we only iterate over blocks within the current window start_idx_n = offs_w * WINDOW_SIZE end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) for start_n in range(start_idx_n, end_idx_n, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- if EVEN_N & EVEN_M: if EVEN_HEADDIM: k = tl.load( k_ptrs + start_n * stride_kn ) else: k = tl.load( k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: k = tl.load( k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(k)) # Trying to combine the two masks seem to make the result wrong if not EVEN_N: # Need to mask out otherwise the softmax is wrong qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) if MASK_TYPE == 1: if EVEN_M & EVEN_W: window_mask = tl.load( window_mask_ptrs + start_n - start_idx_n ) else: window_mask = tl.load( window_mask_ptrs + start_n - start_idx_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE), other=1, ) # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. # we assume mask already implies the causal masking qk = qk * qk_scale qk = tl.where(window_mask, float("-inf"), qk) m_ij = tl.maximum(tl.max(qk, 1), m_i) masked_out_rows = (m_ij == float("-inf")) m_ij_masked = tl.where(masked_out_rows, 0, m_ij) p = tl.exp2(qk - m_ij_masked[:, None]) else: qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i) masked_out_rows = (m_ij == float("-inf")) m_ij_masked = tl.where(masked_out_rows, 0, m_ij) p = tl.exp2(qk * qk_scale - m_ij_masked[:, None]) d_ij = tl.sum(p, 1) # scale acc_o prev_scale = tl.exp2(m_i - m_ij_masked) # # -- update output accumulator -- acc_o = acc_o * prev_scale[:, None] # update acc_o if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition if EVEN_HEADDIM: v = tl.load( v_ptrs + start_n * stride_vn ) else: v = tl.load( v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: v = tl.load( v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0, ) else: v = tl.load( v_ptrs + start_n * stride_vn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0, ) p = p.to(v.dtype) acc_o = tl.dot(p, v, acc_o) # -- update statistics d_i = d_i * prev_scale + d_ij m_i = m_ij if EMPTY_RFA_KV == 0: # Iterate over RFA chunks # we only iterate over chunks before the current local singleton window end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks) for start_c in range(0, end_idx_c, BLOCK_N): start_c = tl.multiple_of(start_c, BLOCK_N) # -- compute qk ---- if EVEN_C & EVEN_M: if EVEN_HEADDIM: rfa_k = tl.load( rfa_k_ptrs + start_c * stride_rfa_kc ) else: rfa_k = tl.load( rfa_k_ptrs + start_c * stride_rfa_kc, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: rfa_k = tl.load( rfa_k_ptrs + start_c * stride_rfa_kc, mask=(start_c + offs_c)[:, None] < nchunks, other=0.0, ) else: rfa_k = tl.load( rfa_k_ptrs + start_c * stride_rfa_kc, mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0, ) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(rfa_k)) # Trying to combine the two masks seem to make the result wrong if not EVEN_C: # Need to mask out otherwise the softmax is wrong qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf")) if MASK_TYPE == 1: if EVEN_C & EVEN_M: chunk_mask = tl.load( chunk_mask_ptrs + start_c ) else: chunk_mask = tl.load( chunk_mask_ptrs + start_c, mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks), other=1, ) # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler # can then fuse the mult and add into an fma instruction. But if we have bias we need to # to multiply with softmax_scale here. # we assume mask already implies the causal masking qk = qk * qk_scale qk = tl.where(chunk_mask, float("-inf"), qk) m_ij = tl.maximum(tl.max(qk, 1), m_i) masked_out_rows = (m_ij == float("-inf")) m_ij_masked = tl.where(masked_out_rows, 0, m_ij) p = tl.exp2(qk - m_ij_masked[:, None]) else: m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i) masked_out_rows = (m_ij == float("-inf")) m_ij_masked = tl.where(masked_out_rows, 0, m_ij) p = tl.exp2(qk * qk_scale - m_ij_masked[:, None]) d_ij = tl.sum(p, 1) # scale acc_o prev_scale = tl.exp2(m_i - m_ij_masked) # # -- update output accumulator -- acc_o = acc_o * prev_scale[:, None] # update acc_o # TODO: If we just do "if EVEN_N", there seems to be some race condition ? if EVEN_C & EVEN_M: if EVEN_HEADDIM: rfa_v = tl.load( rfa_v_ptrs + start_c * stride_rfa_vc ) else: rfa_v = tl.load( rfa_v_ptrs + start_c * stride_rfa_vc, mask=offs_d[None, :] < headdim, other=0.0 ) else: if EVEN_HEADDIM: rfa_v = tl.load( rfa_v_ptrs + start_c * stride_rfa_vc, mask=(start_c + offs_n)[:, None] < nchunks, other=0.0, ) else: rfa_v = tl.load( rfa_v_ptrs + start_c * stride_rfa_vc, mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0, ) p = p.to(rfa_v.dtype) acc_o = tl.dot(p, rfa_v, acc_o) # -- update statistics d_i = d_i * prev_scale + d_ij m_i = m_ij # for rows that are all -inf, set d_i to 1.0 d_i = tl.where(d_i == 0.0, 1.0, d_i) # multiply by log(2) lse_m = (m_i + tl.math.log2(d_i)) * 0.6931471805599453 acc_o = acc_o / d_i[:, None] # TODO: understand why rematerialize offsets to save registers? start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, BLOCK_HEADDIM) out_ptrs = ( Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) ) if EVEN_M: if EVEN_HEADDIM: tl.store( out_ptrs, acc_o ) else: tl.store( out_ptrs, acc_o, mask=offs_d[None, :] < headdim ) else: if EVEN_HEADDIM: tl.store( out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q ) else: tl.store( out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) ) lse_ptrs = ( LSE + off_b * stride_lse_b + off_h * stride_lse_h + offs_m ) if EVEN_M: tl.store( lse_ptrs, lse_m, ) else: tl.store( lse_ptrs, lse_m, mask=offs_m < seqlen_q ) def triton_eva_agg_fwd( q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window ): if rfa_k is None and rfa_v is None: empty_rfa_kv = 1 q, k, v = [ x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v] ] else: assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time." empty_rfa_kv = 0 q, k, v, rfa_k, rfa_v = [ x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v, rfa_k, rfa_v] ] # shape constraints batch, nheads, seqlen_q, head_dim = q.shape _, _, seqlen_k, _ = k.shape if empty_rfa_kv == 0: nchunks = rfa_k.shape[-2] assert rfa_k.shape == (batch, nheads, nchunks, head_dim) assert rfa_v.shape == (batch, nheads, nchunks, head_dim) assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype else: nchunks = 0 assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" assert k.shape == (batch, nheads, seqlen_k, head_dim) assert v.shape == (batch, nheads, seqlen_k, head_dim) assert head_dim <= 128, "We only test head dimensions up to 128" # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now" assert q.is_cuda and k.is_cuda and v.is_cuda softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim) mask_type = 0 if window_mask is not None: mask_type = 1 assert window_mask.dtype == torch.bool assert window_mask.is_cuda assert window_mask.dim() == 4 assert window_mask.shape == (batch, 1, seqlen_q, window_size) if window_mask.stride(-1) != 1: window_mask = window_mask.contiguous() assert chunk_mask is not None assert chunk_mask.dtype == torch.bool assert chunk_mask.is_cuda assert chunk_mask.dim() == 4 assert chunk_mask.shape == (batch, 1, seqlen_q, nchunks) if chunk_mask.stride(-1) != 1: chunk_mask = chunk_mask.contiguous() chunk_mask_strides = ( (chunk_mask.stride(0), chunk_mask.stride(2)) if mask_type == 1 else (0, 0) ) window_mask_strides = ( (window_mask.stride(0), window_mask.stride(2)) if mask_type == 1 else (0, 0) ) rfa_k_strides = ( (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2)) if empty_rfa_kv == 0 else (0, 0, 0) ) rfa_v_strides = ( (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2)) if empty_rfa_kv == 0 else (0, 0, 0) ) o = torch.empty_like(q) lse = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16) BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "fwd") assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK" assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N" grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) _fwd_eva_agg_kernel[grid]( q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, o, lse, softmax_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2], rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2], window_mask_strides[0], window_mask_strides[1], chunk_mask_strides[0], chunk_mask_strides[1], o.stride(0), o.stride(1), o.stride(2), lse.stride(0), lse.stride(1), nheads, seqlen_q, seqlen_k, nchunks, head_dim, chunks_per_window, window_size, mask_type, empty_rfa_kv, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, ) return o, lse def triton_eva_agg_bwd( do, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, o, lse, dq, dk, dv, d_rfa_k, d_rfa_v, softmax_scale, window_size, chunks_per_window, empty_rfa_kv, mask_type, ): if do.stride(-1) != 1: do = do.contiguous() # shape constraints batch, nheads, seqlen_q, head_dim = q.shape _, _, seqlen_k, _ = k.shape if empty_rfa_kv == 0: nchunks = rfa_k.shape[-2] assert rfa_k.shape == (batch, nheads, nchunks, head_dim) assert rfa_v.shape == (batch, nheads, nchunks, head_dim) assert d_rfa_k.stride(-1) == d_rfa_v.stride(-1) == 1 assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype else: nchunks = 0 assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" assert lse.shape == (batch, nheads, seqlen_q) assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == rfa_k.stride(-1) == rfa_v.stride(-1) == 1 assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim) assert head_dim <= 128, "We only test head dimensions up to 128" window_mask_strides = ( (window_mask.stride(0), window_mask.stride(2)) if mask_type == 1 else (0, 0) ) chunk_mask_strides = ( (chunk_mask.stride(0), chunk_mask.stride(2)) if mask_type == 1 else (0, 0) ) rfa_k_strides = ( (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2)) if empty_rfa_kv == 0 else (0, 0, 0) ) rfa_v_strides = ( (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2)) if empty_rfa_kv == 0 else (0, 0, 0) ) d_rfa_k_strides = ( (d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2)) if empty_rfa_kv == 0 else (0, 0, 0) ) d_rfa_v_strides = ( (d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2)) if empty_rfa_kv == 0 else (0, 0, 0) ) BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16) do_t_o = torch.sum(do.to(torch.float32) * o.to(torch.float32), dim=-1).to(do.dtype) BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dq") assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than BLOCK" assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK" grid = lambda META: ( triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads, ) _bwd_eva_agg_kernel_dq[grid]( q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, do, lse, do_t_o, dq, softmax_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2], rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2], window_mask_strides[0], window_mask_strides[1], chunk_mask_strides[0], chunk_mask_strides[1], do.stride(0), do.stride(1), do.stride(2), lse.stride(0), lse.stride(1), do_t_o.stride(0), do_t_o.stride(1), dq.stride(0), dq.stride(1), dq.stride(2), nheads, seqlen_q, seqlen_k, nchunks, head_dim, chunks_per_window, window_size, mask_type, empty_rfa_kv, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, ) BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dkdv") grid = lambda META: ( triton.cdiv(seqlen_k, META["BLOCK_N"]), batch * nheads, ) _bwd_eva_agg_kernel_dkdv[grid]( q, k, v, window_mask, do, lse, do_t_o, dk, dv, softmax_scale, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), v.stride(0), v.stride(1), v.stride(2), window_mask_strides[0], window_mask_strides[1], do.stride(0), do.stride(1), do.stride(2), lse.stride(0), lse.stride(1), do_t_o.stride(0), do_t_o.stride(1), dk.stride(0), dk.stride(1), dk.stride(2), dv.stride(0), dv.stride(1), dv.stride(2), nheads, seqlen_q, seqlen_k, head_dim, window_size, mask_type, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, ) if empty_rfa_kv == 0: BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_drfa_kv") grid = lambda META: ( triton.cdiv(nchunks, META["BLOCK_N"]), batch * nheads, ) _bwd_eva_agg_kernel_drfa_kv[grid]( q, rfa_k, rfa_v, chunk_mask, do, lse, do_t_o, d_rfa_k, d_rfa_v, softmax_scale, q.stride(0), q.stride(1), q.stride(2), rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2], rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2], chunk_mask_strides[0], chunk_mask_strides[1], do.stride(0), do.stride(1), do.stride(2), lse.stride(0), lse.stride(1), do_t_o.stride(0), do_t_o.stride(1), d_rfa_k_strides[0], d_rfa_k_strides[1], d_rfa_k_strides[2], d_rfa_v_strides[0], d_rfa_v_strides[1], d_rfa_v_strides[2], nheads, seqlen_q, nchunks, head_dim, chunks_per_window, window_size, mask_type, BLOCK_HEADDIM, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, num_warps=num_warps, num_stages=num_stages, ) class EvaAggFunc(torch.autograd.Function): @staticmethod def forward(ctx, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None): if rfa_k is None and rfa_v is None: empty_rfa_kv = 1 else: assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time." empty_rfa_kv = 0 if window_mask is not None: mask_type = 1 else: mask_type = 0 o, lse = triton_eva_agg_fwd( q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window ) ctx.save_for_backward(q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask) ctx.softmax_scale = softmax_scale ctx.window_size = window_size ctx.chunks_per_window = chunks_per_window ctx.empty_rfa_kv = empty_rfa_kv ctx.mask_type = mask_type return o @staticmethod def backward(ctx, do): q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask = ctx.saved_tensors dq = torch.empty_like(q) dk = torch.empty_like(k) dv = torch.empty_like(v) if ctx.empty_rfa_kv == 0: d_rfa_k = torch.empty_like(rfa_k) d_rfa_v = torch.empty_like(rfa_v) else: d_rfa_k = None d_rfa_v = None triton_eva_agg_bwd( do, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, o, lse, dq, dk, dv, d_rfa_k, d_rfa_v, softmax_scale=ctx.softmax_scale, window_size=ctx.window_size, chunks_per_window=ctx.chunks_per_window, empty_rfa_kv=ctx.empty_rfa_kv, mask_type=ctx.mask_type, ) return dq, dk, dv, d_rfa_k, d_rfa_v, None, None, None, None, None def eva_agg_func_triton( q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None, ): return EvaAggFunc.apply( q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window, )