|
|
|
import math |
|
import torch |
|
import triton |
|
import triton.language as tl |
|
|
|
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_agg_kernel_dkdv(
    Q,
    K,
    V,
    WindowMask,
    DO,
    LSE,
    DO_T_O,
    DK,
    DV,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_window_mask_b, stride_window_mask_m,
    stride_do_b, stride_do_h, stride_do_m,
    stride_lse_b, stride_lse_h,
    stride_do_t_o_b, stride_do_t_o_h,
    stride_dk_b, stride_dk_h, stride_dk_n,
    stride_dv_b, stride_dv_h, stride_dv_n,
    nheads,
    seqlen_q,
    seqlen_k,
    headdim,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward of the local (windowed) EVA attention part: gradients dK and dV.

    Launch grid (from the program ids read below): axis 0 enumerates blocks of
    BLOCK_N keys, axis 1 enumerates batch * nheads (``off_b = pid1 // nheads``).

    For its key block the kernel walks query blocks from the BLOCK_M-aligned
    block that holds the first key up to the end of that key's window (clamped
    to ``seqlen_q``), recomputes ``p = exp(q @ k^T * softmax_scale - LSE)`` and
    accumulates in fp32:

        dV += p^T @ dO
        dK += dS^T @ Q,   dS = p * (dO @ V^T - DO_T_O) * softmax_scale

    Masking:
      * MASK_TYPE == 1: ``WindowMask`` is indexed by (query row, key position
        inside the window); a True entry excludes the pair. Entries outside the
        valid range load as 1, i.e. excluded.
      * otherwise: causal mask ``query_pos >= key_pos``.

    All tensors are addressed by adding the head-dim offset without a stride,
    so their last dimension must be contiguous. LSE entries equal to -inf are
    treated as 0.

    NOTE(review): DO_T_O is presumably rowsum(dO * O) per query and LSE the
    natural-log log-sum-exp stored by the forward kernel -- confirm against the
    launcher. The window index is taken from the block's first key, so a key
    block is assumed not to straddle two windows (WINDOW_SIZE a multiple of
    BLOCK_N).
    """
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads

    start_n = tl.program_id(0)

    # Window that contains the first key of this block.
    offs_w = (start_n * BLOCK_N) // WINDOW_SIZE
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Position of every key inside that window == its column in WindowMask.
    offs_n_in_window = offs_n - offs_w * WINDOW_SIZE
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        offs_m[:, None] * stride_qm + offs_d[None, :]
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        offs_n[:, None] * stride_kn + offs_d[None, :]
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        offs_n[:, None] * stride_vn + offs_d[None, :]
    )
    do_ptrs = (
        DO +
        off_b * stride_do_b +
        off_h * stride_do_h +
        offs_m[:, None] * stride_do_m + offs_d[None, :]
    )
    do_t_o_ptrs = (
        DO_T_O +
        off_b * stride_do_t_o_b +
        off_h * stride_do_t_o_h +
        offs_m[:, None]
    )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        offs_m[:, None]
    )
    if MASK_TYPE == 1:
        # Rows: query positions (shifted by start_m inside the loop).
        # Columns: key positions inside the window. The mask has no head axis.
        m_ptrs = (
            WindowMask +
            off_b * stride_window_mask_b +
            (offs_m[:, None] * stride_window_mask_m + offs_n_in_window[None, :])
        )
    dk_ptrs = (
        DK +
        off_b * stride_dk_b +
        off_h * stride_dk_h +
        offs_n[:, None] * stride_dk_n + offs_d[None, :]
    )
    dv_ptrs = (
        DV +
        off_b * stride_dv_b +
        off_h * stride_dv_h +
        offs_n[:, None] * stride_dv_n + offs_d[None, :]
    )

    # Query range for this key block: start at the BLOCK_M-aligned block that
    # holds the first key (query blocks entirely before it are skipped) and stop
    # at the end of the key's window, clamped to seqlen_q.
    begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    end_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)

    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # K and V of this block are loaded once and reused for every query block.
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(
                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
            v = tl.load(
                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
            )
    for start_m in range(begin_m, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)

        if EVEN_M & EVEN_N:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m
            )
            lse = tl.load(
                lse_ptrs + start_m
            )
        else:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
            lse = tl.load(
                lse_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
        # A -inf LSE would turn exp(qk - lse) into inf/NaN; use 0 instead.
        lse = tl.where(lse == float("-inf"), 0.0, lse)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        if not EVEN_M:
            # Exclude query rows past the end of the sequence.
            qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_W:
                window_mask = tl.load(
                    m_ptrs + start_m * stride_window_mask_m
                )
            else:
                # Out-of-range entries load as 1 (excluded). The column bound
                # must use only the key's position inside the window; adding
                # the row offset here would mask every query block but the first.
                window_mask = tl.load(
                    m_ptrs + start_m * stride_window_mask_m,
                    mask=((start_m + offs_m)[:, None] < seqlen_q)
                    & (offs_n_in_window[None, :] < WINDOW_SIZE),
                    other=1,
                )
            qk = qk * softmax_scale
            # True in WindowMask means the (query, key) pair is excluded.
            qk = tl.where(window_mask, float("-inf"), qk)
            p = tl.exp(qk - lse)
        else:
            # Causal mask: a query only sees keys at or before its position.
            qk += tl.where((start_m + offs_m)[:, None] >= offs_n[None, :], 0, float("-inf"))
            p = tl.exp(qk * softmax_scale - lse)

        dp = tl.dot(do, tl.trans(v))
        # Softmax backward: dS = P * (dP - DO_T_O), times the score scale.
        ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
        dv += tl.dot(tl.trans(p.to(do.dtype)), do)
        dk += tl.dot(tl.trans(ds), q)
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
|
|
|
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_agg_kernel_drfa_kv(
    Q,
    RFA_K,
    RFA_V,
    ChunkMask,
    DO,
    LSE,
    DO_T_O,
    D_RFA_K,
    D_RFA_V,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_chunk_mask_b, stride_chunk_mask_m,
    stride_do_b, stride_do_h, stride_do_m,
    stride_lse_b, stride_lse_h,
    stride_do_t_o_b, stride_do_t_o_h,
    stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
    stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
    nheads,
    seqlen_q,
    nchunks,
    headdim,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward of the chunk-level (RFA) EVA attention part: gradients dRFA_K, dRFA_V.

    Launch grid (from the program ids read below): axis 0 enumerates blocks of
    BLOCK_N chunks, axis 1 enumerates batch * nheads (``off_b = pid1 // nheads``).

    The block's first chunk lies in window ``offs_w = first_chunk //
    CHUNKS_PER_WINDOW``. The kernel walks every query block from the start of
    the next window (``(offs_w + 1) * WINDOW_SIZE``, clamped to ``seqlen_q``)
    to ``seqlen_q``. For each one it recomputes
    ``p = exp(q @ rfa_k^T * softmax_scale - LSE)`` and accumulates in fp32:

        dRFA_V += p^T @ dO
        dRFA_K += dS^T @ Q,   dS = p * (dO @ RFA_V^T - DO_T_O) * softmax_scale

    With MASK_TYPE == 1, True entries of ``ChunkMask`` (indexed by query row and
    chunk index; no head axis) exclude a pair, and out-of-range entries load as
    1 (excluded). Otherwise no per-chunk mask is applied. LSE entries equal to
    -inf are treated as 0.

    NOTE(review): DO_T_O is presumably rowsum(dO * O) per query -- confirm
    against the launcher. The whole chunk block is treated alike, with no
    per-chunk cutoff at window boundaries. This presumably relies on
    CHUNKS_PER_WINDOW being a multiple of BLOCK_N, or on ChunkMask -- confirm.
    ``tl.multiple_of(start_m, BLOCK_M)`` likewise assumes WINDOW_SIZE is a
    multiple of BLOCK_M.
    """
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads
    start_c = tl.program_id(0)

    # Absolute chunk indices handled by this program.
    offs_c = start_c * BLOCK_N + tl.arange(0, BLOCK_N)

    # Window that contains the first chunk of this block.
    offs_w = (start_c * BLOCK_N) // CHUNKS_PER_WINDOW
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # The head-dim offset is added without a stride: the last dim must be contiguous.
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    do_ptrs = (
        DO +
        off_b * stride_do_b +
        off_h * stride_do_h +
        (offs_m[:, None] * stride_do_m + offs_d[None, :])
    )
    do_t_o_ptrs = (
        DO_T_O +
        off_b * stride_do_t_o_b +
        off_h * stride_do_t_o_h +
        (offs_m[:, None])
    )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        (offs_m[:, None])
    )
    rfa_k_ptrs = (
        RFA_K +
        off_b * stride_rfa_kb +
        off_h * stride_rfa_kh +
        (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
    )
    rfa_v_ptrs = (
        RFA_V +
        off_b * stride_rfa_vb +
        off_h * stride_rfa_vh +
        (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
    )
    if MASK_TYPE == 1:
        # Rows: query positions (shifted by start_m in the loop); columns: chunk index.
        rfa_m_ptrs = (
            ChunkMask +
            off_b * stride_chunk_mask_b +
            (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
        )
    d_rfa_k_ptrs = (
        D_RFA_K +
        off_b * stride_d_rfa_k_b +
        off_h * stride_d_rfa_k_h +
        (offs_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
    )
    d_rfa_v_ptrs = (
        D_RFA_V +
        off_b * stride_d_rfa_v_b +
        off_h * stride_d_rfa_v_h +
        (offs_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
    )

    d_rfa_k = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    d_rfa_v = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # RFA_K / RFA_V of this chunk block are loaded once and reused for every query block.
    if EVEN_C & EVEN_M:
        if EVEN_HEADDIM:
            rfa_k = tl.load(rfa_k_ptrs)
            rfa_v = tl.load(rfa_v_ptrs)
        else:
            rfa_k = tl.load(rfa_k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            rfa_v = tl.load(rfa_v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            rfa_k = tl.load(rfa_k_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
            rfa_v = tl.load(rfa_v_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
        else:
            rfa_k = tl.load(
                rfa_k_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
            )
            rfa_v = tl.load(
                rfa_v_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
            )
    # Only queries from windows after the chunk's window read these chunks.
    begin_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
    end_m = seqlen_q
    for start_m in range(begin_m, end_m, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)

        if EVEN_M:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m
            )
            lse = tl.load(
                lse_ptrs + start_m
            )
        else:
            if EVEN_HEADDIM:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=(start_m + offs_m)[:, None] < seqlen_q,
                    other=0.0
                )
            else:
                q = tl.load(
                    q_ptrs + start_m * stride_qm,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
                do = tl.load(
                    do_ptrs + start_m * stride_do_m,
                    mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                    other=0.0
                )
            do_t_o = tl.load(
                do_t_o_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
            lse = tl.load(
                lse_ptrs + start_m,
                mask=(start_m + offs_m)[:, None] < seqlen_q,
                other=0.0
            )
        # A -inf LSE would turn exp(qk - lse) into inf/NaN; use 0 instead.
        lse = tl.where(lse == float("-inf"), 0.0, lse)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(rfa_k))
        if not EVEN_M:
            # Exclude query rows past the end of the sequence.
            qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_C:
                mask = tl.load(
                    rfa_m_ptrs + (start_m * stride_chunk_mask_m)
                )
            else:
                # Out-of-range entries load as 1, i.e. excluded.
                mask = tl.load(
                    rfa_m_ptrs + (start_m * stride_chunk_mask_m),
                    mask=((start_m + offs_m)[:, None] < seqlen_q)
                    & (offs_c[None, :] < nchunks),
                    other=1,
                )

            qk = qk * softmax_scale
            # True in ChunkMask means the (query, chunk) pair is excluded.
            qk = tl.where(mask, float("-inf"), qk)
            p = tl.exp(qk - lse)
        else:
            p = tl.exp(qk * softmax_scale - lse)

        dp = tl.dot(do, tl.trans(rfa_v))
        # Softmax backward: dS = P * (dP - DO_T_O), times the score scale.
        ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)

        d_rfa_v += tl.dot(tl.trans(p.to(do.dtype)), do)

        d_rfa_k += tl.dot(tl.trans(ds), q)
    if EVEN_C & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(d_rfa_v_ptrs, d_rfa_v)
            tl.store(d_rfa_k_ptrs, d_rfa_k)
        else:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_d[None, :] < headdim)
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_c[:, None] < nchunks)
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_c[:, None] < nchunks)
        else:
            tl.store(d_rfa_v_ptrs, d_rfa_v, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
            tl.store(d_rfa_k_ptrs, d_rfa_k, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
|
|
|
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_eva_agg_kernel_dq(
    Q,
    K,
    V,
    RFA_K,
    RFA_V,
    WindowMask,
    ChunkMask,
    DO,
    LSE,
    DO_T_O,
    DQ,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_window_mask_b, stride_window_mask_m,
    stride_chunk_mask_b, stride_chunk_mask_m,
    stride_do_b, stride_do_h, stride_do_m,
    stride_lse_b, stride_lse_h,
    stride_do_t_o_b, stride_do_t_o_h,
    stride_dq_b, stride_dq_h, stride_dq_m,
    nheads,
    seqlen_q,
    seqlen_k,
    nchunks,
    headdim,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    EMPTY_RFA_KV: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward of EVA attention w.r.t. the queries: dQ for one block of BLOCK_M queries.

    Launch grid (from the program ids read below): axis 0 enumerates query
    blocks, axis 1 enumerates batch * nheads (``off_b = pid1 // nheads``).

    It mirrors the two key sets of the forward kernel:
      1. exact ``K``/``V`` from the start of the query block's window
         (``offs_w * WINDOW_SIZE``) to the end of the query block (clamped to
         ``seqlen_k``);
      2. unless ``EMPTY_RFA_KV``, chunk rows ``[0, offs_w * CHUNKS_PER_WINDOW)``
         of ``RFA_K``/``RFA_V`` (clamped to ``nchunks``).

    For each key tile it recomputes ``p`` from ``LSE`` and accumulates in fp32:

        dQ += dS @ K,   dS = p * (dO @ V^T - DO_T_O) * softmax_scale

    The same formula is applied with RFA_K / RFA_V. Exponentials are taken in
    base 2: both the scores and LSE are scaled by log2(e) = 1.4426950408889634,
    so ``exp2`` yields the same ``p`` as ``exp(qk * softmax_scale - LSE)``.

    Masking: with MASK_TYPE == 1, True entries of ``WindowMask`` (column = key
    position inside the window) and ``ChunkMask`` exclude a pair, and
    out-of-range entries load as 1. Otherwise a causal mask is applied to the
    exact keys and no per-chunk mask to the RFA chunks.

    NOTE(review): DO_T_O is presumably rowsum(dO * O) per query -- confirm.
    ``tl.multiple_of(start_n, BLOCK_N)`` assumes WINDOW_SIZE is a multiple of
    BLOCK_N. The last chunk tile is not cut at ``end_idx_c``, which presumably
    relies on CHUNKS_PER_WINDOW being a multiple of BLOCK_N -- confirm.
    """
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Window that contains the first query of this block.
    offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
    # Tile-relative offsets; the loops below add start_n / start_c.
    offs_n = tl.arange(0, BLOCK_N)
    offs_c = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # The head-dim offset is added without a stride: the last dim must be contiguous.
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if EMPTY_RFA_KV == 0:
        rfa_k_ptrs = (
            RFA_K +
            off_b * stride_rfa_kb +
            off_h * stride_rfa_kh +
            (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
        )
        rfa_v_ptrs = (
            RFA_V +
            off_b * stride_rfa_vb +
            off_h * stride_rfa_vh +
            (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
        )
    dq_ptrs = (
        DQ +
        off_b * stride_dq_b +
        off_h * stride_dq_h +
        (offs_m[:, None] * stride_dq_m + offs_d[None, :])
    )
    do_ptrs = (
        DO +
        off_b * stride_do_b +
        off_h * stride_do_h +
        (offs_m[:, None] * stride_do_m + offs_d[None, :])
    )
    do_t_o_ptrs = (
        DO_T_O +
        off_b * stride_do_t_o_b +
        off_h * stride_do_t_o_h +
        offs_m[:, None]
    )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        offs_m[:, None]
    )

    # q, dO, DO_T_O and LSE for this query block are loaded once.
    if EVEN_M:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs
            )
            do = tl.load(
                do_ptrs
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
            do = tl.load(
                do_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
        do_t_o = tl.load(
            do_t_o_ptrs
        )
        lse = tl.load(
            lse_ptrs
        )
    else:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
            do = tl.load(
                do_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )
            do = tl.load(
                do_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )
        do_t_o = tl.load(
            do_t_o_ptrs,
            mask=offs_m[:, None] < seqlen_q,
            other=0.0
        )
        lse = tl.load(
            lse_ptrs,
            mask=offs_m[:, None] < seqlen_q,
            other=0.0
        )
    # A -inf LSE would turn exp2(qk - lse) into inf/NaN; use 0 instead.
    lse = tl.where(lse == float("-inf"), 0.0, lse)
    # Convert LSE and the score scale to base 2 (multiply by log2(e)) for exp2.
    lse *= 1.4426950408889634
    qk_scale = softmax_scale
    qk_scale *= 1.4426950408889634
    if MASK_TYPE == 1:
        # Rows: query positions; columns: key position inside the window
        # (the loop adds start_n - start_idx_n). The masks have no head axis.
        window_mask_ptrs = (
            WindowMask +
            off_b * stride_window_mask_b +
            (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
        )
        if EMPTY_RFA_KV == 0:
            chunk_mask_ptrs = (
                ChunkMask +
                off_b * stride_chunk_mask_b +
                (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
            )

    dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)

    # Part 1: exact keys, from the window start to the end of this query block.
    start_idx_n = offs_w * WINDOW_SIZE
    end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        if not EVEN_N:
            # Exclude key columns past the end of the sequence.
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_W:
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n
                )
            else:
                # Out-of-range entries load as 1, i.e. excluded.
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n,
                    mask=(offs_m[:, None] < seqlen_q)
                    & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
                    other=1,
                )

            qk = qk * qk_scale
            # True in WindowMask means the (query, key) pair is excluded.
            qk = tl.where(window_mask, float("-inf"), qk)
            p = tl.exp2(qk - lse)
        else:
            # Causal mask: a query only sees keys at or before its position.
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
            p = tl.exp2(qk * qk_scale - lse)

        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        dp = tl.dot(do, tl.trans(v))
        # Softmax backward: dS = P * (dP - DO_T_O), times the (natural) score scale.
        ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
        dq += tl.dot(ds, k)

    # Part 2: chunk-level RFA keys/values of all windows before this one.
    if EMPTY_RFA_KV == 0:
        end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
        for start_c in range(0, end_idx_c, BLOCK_N):
            start_c = tl.multiple_of(start_c, BLOCK_N)

            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(rfa_k))

            if not EVEN_C:
                # Exclude chunk columns past nchunks.
                qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))

            if MASK_TYPE == 1:
                if EVEN_C & EVEN_M:
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c
                    )
                else:
                    # Out-of-range entries load as 1, i.e. excluded.
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c,
                        mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
                        other=1,
                    )

                qk = qk * qk_scale
                # True in ChunkMask means the (query, chunk) pair is excluded.
                qk = tl.where(chunk_mask, float("-inf"), qk)
                p = tl.exp2(qk - lse)
            else:
                p = tl.exp2(qk * qk_scale - lse)

            # offs_n and offs_c are both tl.arange(0, BLOCK_N), so this mask
            # matches the one used for rfa_k.
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_n)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            dp = tl.dot(do, tl.trans(rfa_v))
            ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
            dq += tl.dot(ds, rfa_k)

    # Offsets and dq_ptrs are recomputed here; they equal the values above.
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    dq_ptrs = (
        DQ +
        off_b * stride_dq_b +
        off_h * stride_dq_h +
        (offs_m[:, None] * stride_dq_m + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(
                dq_ptrs, dq
            )
        else:
            tl.store(
                dq_ptrs, dq,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                dq_ptrs, dq,
                mask=offs_m[:, None] < seqlen_q
            )
        else:
            tl.store(
                dq_ptrs, dq,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
|
|
|
_capability_90_config = { |
|
"fwd": { |
|
(torch.bfloat16, 64): (128, 128, 4, 3), |
|
(torch.bfloat16, 128): (128, 128, 8, 3), |
|
(torch.float32, 64): (128, 64, 8, 3), |
|
(torch.float32, 128): (64, 32, 4, 3), |
|
}, |
|
"bwd_dq": { |
|
(torch.bfloat16, 64): (128, 64, 4, 3), |
|
(torch.bfloat16, 128): (128, 64, 8, 3), |
|
(torch.float32, 64): (128, 64, 8, 2), |
|
(torch.float32, 128): (32, 32, 4, 2), |
|
}, |
|
"bwd_dkdv": { |
|
(torch.bfloat16, 64): (128, 64, 4, 2), |
|
(torch.bfloat16, 128): (128, 64, 8, 2), |
|
(torch.float32, 64): (128, 64, 8, 2), |
|
(torch.float32, 128): (32, 32, 4, 1), |
|
}, |
|
"bwd_drfa_kv": { |
|
(torch.bfloat16, 64): (128, 64, 4, 2), |
|
(torch.bfloat16, 128): (128, 64, 8, 2), |
|
(torch.float32, 64): (128, 64, 8, 2), |
|
(torch.float32, 128): (32, 32, 4, 1), |
|
} |
|
} |
|
|
|
_capability_80_config = { |
|
"fwd": { |
|
(torch.bfloat16, 64): (64, 64, 4, 3), |
|
(torch.bfloat16, 128): (64, 64, 8, 3), |
|
(torch.float32, 64): (64, 32, 4, 2), |
|
(torch.float32, 128): (64, 32, 8, 1), |
|
}, |
|
"bwd_dq": { |
|
(torch.bfloat16, 64): (64, 64, 4, 3), |
|
(torch.bfloat16, 128): (64, 32, 4, 2), |
|
(torch.float32, 64): (32, 32, 4, 2), |
|
(torch.float32, 128): (32, 32, 4, 2), |
|
}, |
|
"bwd_dkdv": { |
|
(torch.bfloat16, 64): (64, 64, 4, 3), |
|
(torch.bfloat16, 128): (32, 32, 4, 2), |
|
(torch.float32, 64): (32, 32, 4, 1), |
|
(torch.float32, 128): (16, 64, 8, 1), |
|
}, |
|
"bwd_drfa_kv": { |
|
(torch.bfloat16, 64): (64, 64, 4, 3), |
|
(torch.bfloat16, 128): (64, 32, 4, 3), |
|
(torch.float32, 64): (32, 32, 4, 1), |
|
(torch.float32, 128): (32, 32, 4, 1), |
|
} |
|
} |
|
|
|
def _get_config(dtype, head_dim, mode) -> tuple[int, int, int, int]:
    """Pick a kernel launch configuration for the current CUDA device.

    Args:
        dtype: Element dtype of the attention inputs (e.g. ``torch.bfloat16``).
        head_dim: Per-head feature dimension.
        mode: Pass name: ``"fwd"``, ``"bwd_dq"``, ``"bwd_dkdv"`` or
            ``"bwd_drfa_kv"``.

    Returns:
        A 4-tuple of ints taken from the per-capability table. Unlisted
        ``(dtype, head_dim)`` pairs get a conservative default. On devices
        below compute capability 8.0, fixed fallbacks are chosen by pass type
        and dtype only.

    Raises:
        KeyError: on compute capability >= 8.0 when ``mode`` is not in the
            table.
    """
    device_capability = torch.cuda.get_device_capability()
    lookup_key = (dtype, head_dim)
    if device_capability >= (9, 0):
        return _capability_90_config[mode].get(lookup_key, (32, 32, 4, 1))
    if device_capability >= (8, 0):
        return _capability_80_config[mode].get(lookup_key, (16, 16, 4, 1))

    # Older GPUs: every non-"fwd" mode shares the backward fallback.
    is_fp32 = dtype == torch.float32
    if mode == "fwd":
        return (32, 16, 4, 2) if is_fp32 else (64, 32, 4, 2)
    return (16, 16, 4, 1) if is_fp32 else (32, 32, 4, 1)
|
|
|
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_eva_agg_kernel(
    Q,
    K,
    V,
    RFA_K,
    RFA_V,
    WindowMask,
    ChunkMask,
    Out,
    LSE,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_window_mask_b, stride_window_mask_m,
    stride_chunk_mask_b, stride_chunk_mask_m,
    stride_ob, stride_oh, stride_om,
    stride_lse_b, stride_lse_h,
    nheads,
    seqlen_q,
    seqlen_k,
    nchunks,
    headdim,
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    EMPTY_RFA_KV: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Forward pass of EVA attention for one block of BLOCK_M queries.

    Launch grid (from the program ids read below): axis 0 enumerates query
    blocks, axis 1 enumerates batch * nheads (``off_b = pid1 // nheads``).

    Each query row attends to two key sets, merged by a single online
    (streaming) softmax:
      1. exact ``K``/``V`` from the start of the query block's window
         (``offs_w * WINDOW_SIZE``) to the end of the query block (clamped to
         ``seqlen_k``);
      2. unless ``EMPTY_RFA_KV``, chunk rows ``[0, offs_w * CHUNKS_PER_WINDOW)``
         of ``RFA_K``/``RFA_V`` (clamped to ``nchunks``), i.e. the chunks of
         all earlier windows.

    Masking: with MASK_TYPE == 1, True entries of ``WindowMask`` (column = key
    position inside the window) and ``ChunkMask`` exclude a pair, and
    out-of-range entries load as 1. Otherwise a causal mask is applied to the
    exact keys and no per-chunk mask to the RFA chunks.

    The softmax runs in base 2: scores are scaled by
    ``softmax_scale * log2(e)`` and exponentiated with ``exp2``. The kernel
    writes the normalised output to ``Out`` and the natural-log log-sum-exp
    of each row to ``LSE``. A row with no unmasked key gets a zero output and
    ``LSE = -inf``.

    NOTE(review): ``tl.multiple_of(start_n, BLOCK_N)`` assumes WINDOW_SIZE is a
    multiple of BLOCK_N. The last chunk tile is not cut at ``end_idx_c``,
    which presumably relies on CHUNKS_PER_WINDOW being a multiple of BLOCK_N --
    confirm against the launcher.
    """
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # Window that contains the first query of this block.
    offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
    # Tile-relative offsets; the loops below add start_n / start_c.
    offs_n = tl.arange(0, BLOCK_N)
    offs_c = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # The head-dim offset is added without a stride: the last dim must be contiguous.
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if EMPTY_RFA_KV == 0:
        rfa_k_ptrs = (
            RFA_K +
            off_b * stride_rfa_kb +
            off_h * stride_rfa_kh +
            (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
        )
        rfa_v_ptrs = (
            RFA_V +
            off_b * stride_rfa_vb +
            off_h * stride_rfa_vh +
            (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
        )

    # Score scale in base 2: softmax_scale * log2(e), for use with exp2.
    qk_scale = softmax_scale
    qk_scale *= 1.4426950408889634
    if MASK_TYPE == 1:
        # Rows: query positions; columns: key position inside the window
        # (the loop adds start_n - start_idx_n). The masks have no head axis.
        window_mask_ptrs = (
            WindowMask +
            off_b * stride_window_mask_b +
            (offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
        )
        if EMPTY_RFA_KV == 0:
            chunk_mask_ptrs = (
                ChunkMask +
                off_b * stride_chunk_mask_b +
                (offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
            )

    # Online-softmax state: m_i = running row max (base-2 scaled scores),
    # d_i = running denominator, acc_o = running unnormalised output.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)

    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )

    # Part 1: exact keys, from the window start to the end of this query block.
    start_idx_n = offs_w * WINDOW_SIZE
    end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        if not EVEN_N:
            # Exclude key columns past the end of the sequence.
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_W:
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n
                )
            else:
                # Out-of-range entries load as 1, i.e. excluded.
                window_mask = tl.load(
                    window_mask_ptrs + start_n - start_idx_n,
                    mask=(offs_m[:, None] < seqlen_q)
                    & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
                    other=1,
                )

            qk = qk * qk_scale
            # True in WindowMask means the (query, key) pair is excluded.
            qk = tl.where(window_mask, float("-inf"), qk)
            m_ij = tl.maximum(tl.max(qk, 1), m_i)
            # If a row is fully masked so far (max == -inf), subtract 0 instead
            # so that exp2(-inf - (-inf)) does not produce NaN.
            masked_out_rows = (m_ij == float("-inf"))
            m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
            p = tl.exp2(qk - m_ij_masked[:, None])
        else:
            # Causal mask: a query only sees keys at or before its position.
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            masked_out_rows = (m_ij == float("-inf"))
            m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
            p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])

        d_ij = tl.sum(p, 1)

        # Rescale the previous accumulator/denominator to the new running max
        # (exp2(-inf) = 0 on the first tile, where m_i is still -inf).
        prev_scale = tl.exp2(m_i - m_ij_masked)

        acc_o = acc_o * prev_scale[:, None]

        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        # acc_o += p @ v (accumulator passed as the third argument).
        acc_o = tl.dot(p, v, acc_o)

        d_i = d_i * prev_scale + d_ij
        m_i = m_ij

    # Part 2: chunk-level RFA keys/values of all windows before this one,
    # folded into the same running softmax state.
    if EMPTY_RFA_KV == 0:
        end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
        for start_c in range(0, end_idx_c, BLOCK_N):
            start_c = tl.multiple_of(start_c, BLOCK_N)

            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(rfa_k))

            if not EVEN_C:
                # Exclude chunk columns past nchunks.
                qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))

            if MASK_TYPE == 1:
                if EVEN_C & EVEN_M:
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c
                    )
                else:
                    # Out-of-range entries load as 1, i.e. excluded.
                    chunk_mask = tl.load(
                        chunk_mask_ptrs + start_c,
                        mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
                        other=1,
                    )

                qk = qk * qk_scale
                # True in ChunkMask means the (query, chunk) pair is excluded.
                qk = tl.where(chunk_mask, float("-inf"), qk)
                m_ij = tl.maximum(tl.max(qk, 1), m_i)
                masked_out_rows = (m_ij == float("-inf"))
                m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
                p = tl.exp2(qk - m_ij_masked[:, None])
            else:
                m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
                masked_out_rows = (m_ij == float("-inf"))
                m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
                p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])

            d_ij = tl.sum(p, 1)

            # Rescale the previous state to the new running max.
            prev_scale = tl.exp2(m_i - m_ij_masked)

            acc_o = acc_o * prev_scale[:, None]

            # offs_n and offs_c are both tl.arange(0, BLOCK_N), so this mask
            # matches the one used for rfa_k.
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_n)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_n)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            p = p.to(rfa_v.dtype)
            acc_o = tl.dot(p, rfa_v, acc_o)

            d_i = d_i * prev_scale + d_ij
            m_i = m_ij

    # Rows that saw no unmasked key have d_i == 0; divide by 1 instead
    # (acc_o is 0 for them, so the output is 0).
    d_i = tl.where(d_i == 0.0, 1.0, d_i)

    # log-sum-exp in base 2, converted to natural log by multiplying by ln(2).
    lse_m = (m_i + tl.math.log2(d_i)) * 0.6931471805599453
    acc_o = acc_o / d_i[:, None]

    # Offsets are recomputed here; they equal the values above.
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out +
        off_b * stride_ob +
        off_h * stride_oh +
        (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_m[:, None] < seqlen_q
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
    lse_ptrs = (
        LSE +
        off_b * stride_lse_b +
        off_h * stride_lse_h +
        offs_m
    )
    if EVEN_M:
        tl.store(
            lse_ptrs, lse_m,
        )
    else:
        tl.store(
            lse_ptrs, lse_m,
            mask=offs_m < seqlen_q
        )
|
|
|
def triton_eva_agg_fwd(
    q, k, v, rfa_k, rfa_v,
    window_mask,
    chunk_mask,
    softmax_scale,
    window_size,
    chunks_per_window
):
    """Validate inputs and launch the EVA aggregation forward Triton kernel.

    Args:
        q: (batch, nheads, seqlen_q, head_dim) CUDA tensor, bf16 or fp32.
        k, v: (batch, nheads, seqlen_k, head_dim) CUDA tensors with q's dtype.
        rfa_k, rfa_v: (batch, nheads, nchunks, head_dim) CUDA tensors with q's
            dtype, or both None (then nchunks is treated as 0).
        window_mask: optional bool CUDA tensor of shape
            (batch, 1, seqlen_q, window_size). When given, ``chunk_mask`` is
            required as well and both masks are enabled in the kernel.
        chunk_mask: bool CUDA tensor of shape (batch, 1, seqlen_q, nchunks).
            Only validated when ``window_mask`` is given; otherwise its
            strides are passed to the kernel as zero.
        softmax_scale: score scale; falsy values (None, 0) fall back to
            ``1 / sqrt(head_dim)``.
        window_size: window length forwarded to the kernel.
        chunks_per_window: must be a positive multiple of the kernel's BLOCK_N.

    Returns:
        (o, lse): ``o`` is a new tensor like (the normalized) ``q`` holding the
        kernel output; ``lse`` is a float32 (batch, nheads, seqlen_q) tensor of
        log-sum-exp statistics written by the kernel for the backward pass.

    Raises:
        AssertionError: on any shape, dtype, device or block-size mismatch.
    """
    if rfa_k is None and rfa_v is None:
        empty_rfa_kv = 1
        # The kernel assumes unit stride along head_dim (only the first three
        # strides are passed), so copy inputs whose last dim is strided.
        q, k, v = [
            x if x.stride(-1) == 1 else x.contiguous()
            for x in [q, k, v]
        ]
    else:
        assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
        empty_rfa_kv = 0
        q, k, v, rfa_k, rfa_v = [
            x if x.stride(-1) == 1 else x.contiguous()
            for x in [q, k, v, rfa_k, rfa_v]
        ]

    batch, nheads, seqlen_q, head_dim = q.shape
    _, _, seqlen_k, _ = k.shape
    if empty_rfa_kv == 0:
        nchunks = rfa_k.shape[-2]
        assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
        assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
        assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype, "All tensors must have the same type"
    else:
        nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert k.shape == (batch, nheads, seqlen_k, head_dim)
    assert v.shape == (batch, nheads, seqlen_k, head_dim)

    assert head_dim <= 128, "We only test head dimensions up to 128"

    assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert q.is_cuda and k.is_cuda and v.is_cuda
    if empty_rfa_kv == 0:
        # The RFA tensors are read by the same kernel, so they must live on
        # the GPU too; fail early with an assertion instead of inside Triton.
        assert rfa_k.is_cuda and rfa_v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if window_mask is not None:
        mask_type = 1
        assert window_mask.dtype == torch.bool
        assert window_mask.is_cuda
        assert window_mask.dim() == 4
        assert window_mask.shape == (batch, 1, seqlen_q, window_size)
        if window_mask.stride(-1) != 1:
            window_mask = window_mask.contiguous()

        assert chunk_mask is not None
        assert chunk_mask.dtype == torch.bool
        assert chunk_mask.is_cuda
        assert chunk_mask.dim() == 4
        assert chunk_mask.shape == (batch, 1, seqlen_q, nchunks)
        if chunk_mask.stride(-1) != 1:
            chunk_mask = chunk_mask.contiguous()

    # Masks are broadcast over heads (dim 1 has size 1), so only the batch
    # and query-row strides are needed.
    chunk_mask_strides = (
        (chunk_mask.stride(0), chunk_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )
    window_mask_strides = (
        (window_mask.stride(0), window_mask.stride(2))
        if mask_type == 1 else
        (0, 0)
    )

    rfa_k_strides = (
        (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    rfa_v_strides = (
        (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )

    o = torch.empty_like(q)
    lse = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

    # Head-dim tile: next power of two, with a floor of 16.
    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)

    BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "fwd")

    assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than or equal to BLOCK_N"
    assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"

    # One program per (query tile, batch*head) pair.
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_eva_agg_kernel[grid](
        q,
        k,
        v,
        rfa_k,
        rfa_v,
        window_mask,
        chunk_mask,
        o,
        lse,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
        rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
        window_mask_strides[0], window_mask_strides[1],
        chunk_mask_strides[0], chunk_mask_strides[1],
        o.stride(0), o.stride(1), o.stride(2),
        lse.stride(0), lse.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        nchunks,
        head_dim,
        chunks_per_window,
        window_size,
        mask_type,
        empty_rfa_kv,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, lse
|
|
|
def triton_eva_agg_bwd(
    do,
    q, k, v, rfa_k, rfa_v,
    window_mask, chunk_mask,
    o, lse,
    dq, dk, dv, d_rfa_k, d_rfa_v,
    softmax_scale,
    window_size,
    chunks_per_window,
    empty_rfa_kv,
    mask_type,
):
    """Launch the EVA aggregation backward Triton kernels, writing gradients in place.

    Launches ``_bwd_eva_agg_kernel_dq`` (fills ``dq``),
    ``_bwd_eva_agg_kernel_dkdv`` (fills ``dk`` and ``dv``) and, only when RFA
    key/value tensors are present (``empty_rfa_kv == 0``),
    ``_bwd_eva_agg_kernel_drfa_kv`` (fills ``d_rfa_k`` and ``d_rfa_v``).
    Returns None.

    Args:
        do: gradient w.r.t. the forward output; copied if its last dim is strided.
        q, k, v: forward inputs; q is (batch, nheads, seqlen_q, head_dim) and
            k/v are (batch, nheads, seqlen_k, head_dim). Last dim must be unit-stride.
        rfa_k, rfa_v: (batch, nheads, nchunks, head_dim) tensors, or None when
            ``empty_rfa_kv == 1``.
        window_mask, chunk_mask: mask tensors, used only when ``mask_type == 1``.
        o: forward output (last dim unit-stride).
        lse: (batch, nheads, seqlen_q) statistics saved by the forward pass.
        dq, dk, dv, d_rfa_k, d_rfa_v: preallocated gradient buffers with
            unit-stride last dims (d_rfa_* unused when ``empty_rfa_kv == 1``).
        softmax_scale: falsy values fall back to ``1 / sqrt(head_dim)``.
        window_size, chunks_per_window: forwarded to the kernels;
            ``chunks_per_window`` must be a positive multiple of the dq
            kernel's BLOCK_N.
        empty_rfa_kv: 1 when rfa_k/rfa_v are absent, else 0.
        mask_type: 1 when masks are supplied, else 0.

    Raises:
        AssertionError: on any shape, dtype, stride or block-size mismatch.
    """
    if do.stride(-1) != 1:
        do = do.contiguous()

    batch, nheads, seqlen_q, head_dim = q.shape
    _, _, seqlen_k, _ = k.shape
    if empty_rfa_kv == 0:
        nchunks = rfa_k.shape[-2]
        assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
        assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
        # rfa_k/rfa_v only exist on this path; checking their strides
        # unconditionally would dereference None when empty_rfa_kv == 1.
        assert rfa_k.stride(-1) == rfa_v.stride(-1) == 1
        assert d_rfa_k.stride(-1) == d_rfa_v.stride(-1) == 1
        assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
    else:
        nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"

    assert lse.shape == (batch, nheads, seqlen_q)
    assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
    assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    assert head_dim <= 128, "We only test head dimensions up to 128"

    if mask_type == 1:
        # Only the batch and query-row strides reach the kernels, so each mask
        # must be unit-stride along its last dim (the forward wrapper enforces
        # the same layout before launching its kernel).
        if window_mask.stride(-1) != 1:
            window_mask = window_mask.contiguous()
        if chunk_mask.stride(-1) != 1:
            chunk_mask = chunk_mask.contiguous()
        window_mask_strides = (window_mask.stride(0), window_mask.stride(2))
        chunk_mask_strides = (chunk_mask.stride(0), chunk_mask.stride(2))
    else:
        window_mask_strides = (0, 0)
        chunk_mask_strides = (0, 0)

    if empty_rfa_kv == 0:
        rfa_k_strides = (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
        rfa_v_strides = (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
        d_rfa_k_strides = (d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2))
        d_rfa_v_strides = (d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2))
    else:
        rfa_k_strides = (0, 0, 0)
        rfa_v_strides = (0, 0, 0)
        d_rfa_k_strides = (0, 0, 0)
        d_rfa_v_strides = (0, 0, 0)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)

    # Row-wise sum of do * o, computed in fp32 and stored in do's dtype; the
    # three kernels below all read it.
    do_t_o = torch.sum(do.to(torch.float32) * o.to(torch.float32), dim=-1).to(do.dtype)

    # --- dq: one program per (query tile, batch*head) pair.
    BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dq")

    assert chunks_per_window >= BLOCK_N, "chunks_per_window must be greater than or equal to BLOCK_N"
    assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"
    grid = lambda META: (
        triton.cdiv(seqlen_q, META["BLOCK_M"]),
        batch * nheads,
    )
    _bwd_eva_agg_kernel_dq[grid](
        q,
        k,
        v,
        rfa_k,
        rfa_v,
        window_mask,
        chunk_mask,
        do,
        lse,
        do_t_o,
        dq,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
        rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
        window_mask_strides[0], window_mask_strides[1],
        chunk_mask_strides[0], chunk_mask_strides[1],
        do.stride(0), do.stride(1), do.stride(2),
        lse.stride(0), lse.stride(1),
        do_t_o.stride(0), do_t_o.stride(1),
        dq.stride(0), dq.stride(1), dq.stride(2),
        nheads,
        seqlen_q,
        seqlen_k,
        nchunks,
        head_dim,
        chunks_per_window,
        window_size,
        mask_type,
        empty_rfa_kv,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # --- dk/dv: one program per (key tile, batch*head) pair.
    BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dkdv")
    grid = lambda META: (
        triton.cdiv(seqlen_k, META["BLOCK_N"]),
        batch * nheads,
    )
    _bwd_eva_agg_kernel_dkdv[grid](
        q,
        k,
        v,
        window_mask,
        do,
        lse,
        do_t_o,
        dk,
        dv,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        window_mask_strides[0], window_mask_strides[1],
        do.stride(0), do.stride(1), do.stride(2),
        lse.stride(0), lse.stride(1),
        do_t_o.stride(0), do_t_o.stride(1),
        dk.stride(0), dk.stride(1), dk.stride(2),
        dv.stride(0), dv.stride(1), dv.stride(2),
        nheads,
        seqlen_q,
        seqlen_k,
        head_dim,
        window_size,
        mask_type,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    # --- d_rfa_k/d_rfa_v: only when RFA tensors were supplied.
    if empty_rfa_kv == 0:
        BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_drfa_kv")
        grid = lambda META: (
            triton.cdiv(nchunks, META["BLOCK_N"]),
            batch * nheads,
        )
        _bwd_eva_agg_kernel_drfa_kv[grid](
            q,
            rfa_k,
            rfa_v,
            chunk_mask,
            do,
            lse,
            do_t_o,
            d_rfa_k,
            d_rfa_v,
            softmax_scale,
            q.stride(0), q.stride(1), q.stride(2),
            rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
            rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
            chunk_mask_strides[0], chunk_mask_strides[1],
            do.stride(0), do.stride(1), do.stride(2),
            lse.stride(0), lse.stride(1),
            do_t_o.stride(0), do_t_o.stride(1),
            d_rfa_k_strides[0], d_rfa_k_strides[1], d_rfa_k_strides[2],
            d_rfa_v_strides[0], d_rfa_v_strides[1], d_rfa_v_strides[2],
            nheads,
            seqlen_q,
            nchunks,
            head_dim,
            chunks_per_window,
            window_size,
            mask_type,
            BLOCK_HEADDIM,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            num_warps=num_warps,
            num_stages=num_stages,
        )
|
|
|
|
|
class EvaAggFunc(torch.autograd.Function):
    """Autograd bridge between PyTorch and the Triton EVA aggregation wrappers.

    ``forward`` runs ``triton_eva_agg_fwd`` and saves what ``backward`` needs.
    ``backward`` allocates gradient buffers and fills them via
    ``triton_eva_agg_bwd``.
    """

    @staticmethod
    def forward(ctx, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None):
        """Compute the aggregated output and stash tensors for backward.

        ``rfa_k`` and ``rfa_v`` must both be None or both be tensors. Masks
        are enabled iff ``window_mask`` is not None. Returns the output tensor
        ``o``.
        """
        if rfa_k is None and rfa_v is None:
            empty_rfa_kv = 1
        else:
            assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
            empty_rfa_kv = 0

        if window_mask is not None:
            mask_type = 1
        else:
            mask_type = 0

        def _with_unit_last_stride(x):
            # Mirror the forward wrapper's normalization: copy only tensors
            # whose last dim is strided; pass None through untouched.
            return x if x is None or x.stride(-1) == 1 else x.contiguous()

        # The forward wrapper normalizes layouts internally but returns only
        # (o, lse). Saving the caller's originals would hand backward tensors
        # that fail its unit-stride asserts (also via dq = empty_like(q)) or,
        # for masks, get addressed with only batch/row strides. Normalize here
        # so the saved tensors match what the kernels actually consumed.
        q, k, v, rfa_k, rfa_v = [
            _with_unit_last_stride(x) for x in (q, k, v, rfa_k, rfa_v)
        ]
        if mask_type == 1:
            window_mask = _with_unit_last_stride(window_mask)
            chunk_mask = _with_unit_last_stride(chunk_mask)

        o, lse = triton_eva_agg_fwd(
            q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window
        )
        ctx.save_for_backward(q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask)
        ctx.softmax_scale = softmax_scale
        ctx.window_size = window_size
        ctx.chunks_per_window = chunks_per_window
        ctx.empty_rfa_kv = empty_rfa_kv
        ctx.mask_type = mask_type
        return o

    @staticmethod
    def backward(ctx, do):
        """Return gradients for every ``forward`` input.

        Non-tensor inputs and the masks get None. ``d_rfa_k``/``d_rfa_v`` are
        None when no RFA tensors were given.
        """
        q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        if ctx.empty_rfa_kv == 0:
            d_rfa_k = torch.empty_like(rfa_k)
            d_rfa_v = torch.empty_like(rfa_v)
        else:
            d_rfa_k = None
            d_rfa_v = None
        triton_eva_agg_bwd(
            do,
            q,
            k,
            v,
            rfa_k,
            rfa_v,
            window_mask,
            chunk_mask,
            o,
            lse,
            dq,
            dk,
            dv,
            d_rfa_k,
            d_rfa_v,
            softmax_scale=ctx.softmax_scale,
            window_size=ctx.window_size,
            chunks_per_window=ctx.chunks_per_window,
            empty_rfa_kv=ctx.empty_rfa_kv,
            mask_type=ctx.mask_type,
        )
        # One entry per forward input: q, k, v, rfa_k, rfa_v, window_mask,
        # chunk_mask, softmax_scale, window_size, chunks_per_window.
        return dq, dk, dv, d_rfa_k, d_rfa_v, None, None, None, None, None
|
|
|
|
|
def eva_agg_func_triton(
    q, k, v, rfa_k, rfa_v,
    window_mask, chunk_mask,
    softmax_scale=None, window_size=None, chunks_per_window=None,
):
    """Differentiable functional entry point for Triton-backed EVA aggregation.

    Bundles every argument, in declaration order, and hands them positionally
    to ``EvaAggFunc.apply``; the resulting output tensor is returned as-is.
    """
    positional_args = (
        q, k, v, rfa_k, rfa_v,
        window_mask, chunk_mask,
        softmax_scale, window_size, chunks_per_window,
    )
    return EvaAggFunc.apply(*positional_args)
|
|