# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from fla.ops.utils.solve_tril import solve_tril
from fla.utils import check_shared_mem, is_nvidia_hopper

# Warp counts explored by the backward kernel's autotuner; 8 warps are excluded on Hopper.
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def fwd_recompute_w_u_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # One program per (chunk, batch*head): u = A @ (beta * v) and w = A @ (beta * k)
    # for the BT rows of this chunk. A is read as a (T, BT) matrix with row stride BT
    # (head-first) or H*BT (time-major).
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Variable-length mode: `indices` holds (sequence id, chunk id within sequence) pairs
        # and `offsets` holds cumulative sequence boundaries.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))

    # u = A @ (beta * v), tiled over the value dimension.
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_u = tl.make_block_ptr(u + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A.to(b_vb.dtype), b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # w = A @ (beta * k), tiled over the key dimension.
    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_w = tl.make_block_ptr(w + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A.to(b_kb.dtype), b_kb, allow_tf32=False)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'HEAD_FIRST', 'USE_OFFSETS'],
)
@triton.jit(do_not_specialize=['T'])
def bwd_prepare_wy_repr_kernel(
    k,
    v,
    beta,
    A,
    dw,
    du,
    dk,
    dv,
    dbeta,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # One program per (chunk, batch*head). Back-propagates du, dw through
    # u = A @ (beta * v), w = A @ (beta * k) into dv, dk, dbeta. It then back-propagates
    # the resulting gradient w.r.t. A into k and beta, assuming A is the solve_tril output
    # produced in fwd_prepare_wy_repr.
    # The launch grid is 2-D (NT, B*H), so only program ids 0 and 1 exist.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Note the swapped shape/strides: b_A holds this chunk's A block *transposed*.
    if HEAD_FIRST:
        p_beta = tl.make_block_ptr(beta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + i_bh * T*BT, (BT, T), (1, BT), (0, i_t * BT), (BT, BT), (0, 1))
    else:
        p_beta = tl.make_block_ptr(beta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_A = tl.make_block_ptr(A + (bos*H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1))

    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))

    b_dbeta = tl.zeros([BT], dtype=tl.float32)
    # Gradient w.r.t. A (untransposed), accumulated from both the u and the w paths.
    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
    for i_v in range(tl.cdiv(V, BV)):
        if HEAD_FIRST:
            p_v = tl.make_block_ptr(v + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + i_bh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        else:
            p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_dv = tl.make_block_ptr(dv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            p_du = tl.make_block_ptr(du + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_v_beta = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_du = tl.load(p_du, boundary_check=(0, 1))
        b_dA += tl.dot(b_du, tl.trans(b_v_beta), allow_tf32=False)
        # b_A is A^T, so this is A^T @ du = d(beta * v).
        b_dv_beta = tl.dot(b_A, b_du, allow_tf32=False)
        b_dv = b_dv_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dv_beta * b_v, 1)
        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dw = tl.make_block_ptr(dw + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
        b_dw = tl.load(p_dw, boundary_check=(0, 1))
        b_dA += tl.dot(b_dw, tl.trans(b_k_beta), allow_tf32=False)
        b_dk_beta = tl.dot(b_A, b_dw, allow_tf32=False)
        b_dk = b_dk_beta * b_beta[:, None]
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        # Partial dk (the w path only); the A -> k path is added in the loop below.
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    # Keep only the strictly-lower part of dA, then form -(A^T dA A^T), again strictly lower.
    # This is the gradient w.r.t. the matrix that solve_tril inverted (assumed from the
    # forward pass above).
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], b_dA, 0)
    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
    b_dA = tl.where(tl.arange(0, BT)[:, None] > tl.arange(0, BT)[None, :], -b_dA, 0).to(k.dtype.element_ty)

    for i_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + i_bh * T*K, (T, K), (K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        else:
            p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
            p_dk = tl.make_block_ptr(dk + (bos*H + i_h) * K, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # NOTE(review): reloads the partial dk this program stored above; relies on that
        # global store being visible to the reload without an explicit barrier — confirm.
        b_dk = tl.load(p_dk, boundary_check=(0, 1))
        b_k_beta = (b_k * b_beta[:, None]).to(b_k.dtype)
        # Gradient through (beta * k) @ k^T: one term per factor, plus the beta contribution.
        b_dk_beta = tl.dot(b_dA, b_k, allow_tf32=False)
        b_dbeta += tl.sum(b_dk_beta * b_k, 1)
        b_dk += tl.dot(tl.trans(b_dA), b_k_beta, allow_tf32=False)
        b_dk += b_dk_beta * b_beta[:, None]
        tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))

    if HEAD_FIRST:
        p_dbeta = tl.make_block_ptr(dbeta + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_dbeta = tl.make_block_ptr(dbeta + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    tl.store(p_dbeta, b_dbeta.to(p_dbeta.dtype.element_ty), boundary_check=(0,))


def fwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool = False,
    chunk_size: int = 64
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the chunkwise WY representation and return ``(w, u, A)``.

    ``A`` is computed by ``chunk_scaled_dot_kkt_fwd`` (in float32), passed through
    ``solve_tril`` and cast to ``k.dtype``. Then ``w = A @ (beta * k)`` and
    ``u = A @ (beta * v)`` are formed per chunk by :func:`fwd_recompute_w_u`.

    Args:
        k: ``[B, H, T, K]`` if ``head_first`` else ``[B, T, H, K]``.
        v: same leading layout as ``k`` with last dimension ``V``.
        beta: per-token scalars, ``[B, H, T]`` or ``[B, T, H]`` (per the kernel strides).
        offsets: cumulative sequence boundaries for variable-length input, or ``None``.
        indices: ``(sequence, chunk)`` index pairs matching ``offsets``, or ``None``.
        head_first: selects the head-first memory layout.
        chunk_size: chunk length used to build ``A``.
    """
    A = chunk_scaled_dot_kkt_fwd(
        k=k,
        beta=beta,
        cu_seqlens=offsets,
        head_first=head_first,
        chunk_size=chunk_size,
        output_dtype=torch.float32
    )
    A = solve_tril(
        A=A,
        cu_seqlens=offsets,
        head_first=head_first,
        output_dtype=k.dtype
    )
    w, u = fwd_recompute_w_u(
        k=k,
        v=v,
        beta=beta,
        A=A,
        offsets=offsets,
        indices=indices,
        head_first=head_first,
        chunk_size=chunk_size
    )
    return w, u, A


def fwd_recompute_w_u(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute ``w = A @ (beta * k)`` and ``u = A @ (beta * v)`` chunk by chunk.

    ``A`` is ``[B, H, T, BT]`` (head-first) or ``[B, T, H, BT]``. Its last dimension is
    the chunk length, and it is used as the kernel tile size ``BT``. ``chunk_size`` is
    kept for interface compatibility.

    Returns:
        ``(w, u)`` with the shapes and dtypes of ``k`` and ``v`` respectively.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # The kernel addresses A with row stride BT (H*BT when not head-first). The tile size
    # must therefore equal the column count A was built with. Deriving it from
    # chunk_size/T instead mis-strides A whenever the two disagree.
    BT = A.shape[-1]
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    u = torch.empty_like(v)
    w = torch.empty_like(k)
    fwd_recompute_w_u_kernel[(NT, B*H)](
        k,
        v,
        beta,
        w,
        u,
        A,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return w, u


def bwd_prepare_wy_repr(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    dw: torch.Tensor,
    du: torch.Tensor,
    offsets: Optional[torch.LongTensor],
    indices: Optional[torch.LongTensor],
    head_first: bool,
    chunk_size: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Back-propagate ``dw`` and ``du`` through the WY representation.

    ``A`` is the matrix returned by :func:`fwd_prepare_wy_repr`. Its last dimension is
    used as the kernel tile size ``BT``. ``chunk_size`` is kept for interface
    compatibility.

    Returns:
        ``(dk, dv, dbeta)`` shaped like ``k``, ``v`` and ``beta``.
    """
    if head_first:
        B, H, T, K, V = *k.shape, v.shape[-1]
    else:
        B, T, H, K, V = *k.shape, v.shape[-1]
    # Must match the column count of A: the kernel strides A by BT (see fwd_recompute_w_u).
    BT = A.shape[-1]
    CONST_TILING = 64 if check_shared_mem() else 32
    BK = min(triton.next_power_of_2(K), CONST_TILING)
    BV = min(triton.next_power_of_2(V), CONST_TILING)
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbeta = torch.empty_like(beta)
    bwd_prepare_wy_repr_kernel[(NT, B * H)](
        k,
        v,
        beta,
        A,
        dw,
        du,
        dk,
        dv,
        dbeta,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        HEAD_FIRST=head_first
    )
    return dk, dv, dbeta