# EvaByte-SFT / eva_agg_kernel.py
import math
import torch
import triton
import triton.language as tl
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _bwd_eva_agg_kernel_dkdv(
Q,
K,
V,
WindowMask,
DO,
LSE,
DO_T_O,
DK,
DV,
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_window_mask_b, stride_window_mask_m,
stride_do_b, stride_do_h, stride_do_m,
stride_lse_b, stride_lse_h,
stride_do_t_o_b, stride_do_t_o_h,
stride_dk_b, stride_dk_h, stride_dk_n,
stride_dv_b, stride_dv_h, stride_dv_n,
nheads,
seqlen_q,
seqlen_k,
headdim,
WINDOW_SIZE: tl.constexpr,
MASK_TYPE: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_W: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
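# Backward kernel for dK/dV of the local (singleton) attention: each program owns one
# BLOCK_N tile of K/V and accumulates its gradients over the query blocks that can
# attend to it, i.e. from the causal start of this tile up to the end of its local window.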
off_bh = tl.program_id(1)
off_h = off_bh % nheads
off_b = off_bh // nheads
start_n = tl.program_id(0)
# determine which window the current KV block belongs to
offs_w = (start_n * BLOCK_N) // WINDOW_SIZE
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# initialize pointers
q_ptrs = (
Q +
off_b * stride_qb +
off_h * stride_qh +
offs_m[:, None] * stride_qm + offs_d[None, :]
)
k_ptrs = (
K +
off_b * stride_kb +
off_h * stride_kh +
offs_n[:, None] * stride_kn + offs_d[None, :]
)
v_ptrs = (
V +
off_b * stride_vb +
off_h * stride_vh +
offs_n[:, None] * stride_vn + offs_d[None, :]
)
do_ptrs = (
DO +
off_b * stride_do_b +
off_h * stride_do_h +
offs_m[:, None] * stride_do_m + offs_d[None, :]
)
do_t_o_ptrs = (
DO_T_O +
off_b * stride_do_t_o_b +
off_h * stride_do_t_o_h +
offs_m[:, None]
)
lse_ptrs = (
LSE +
off_b * stride_lse_b +
off_h * stride_lse_h +
offs_m[:, None]
)
if MASK_TYPE == 1:
m_ptrs = (
WindowMask +
off_b * stride_window_mask_b +
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
)
dk_ptrs = (
DK +
off_b * stride_dk_b +
off_h * stride_dk_h +
offs_n[:, None] * stride_dk_n + offs_d[None, :]
)
dv_ptrs = (
DV +
off_b * stride_dv_b +
off_h * stride_dv_h +
offs_n[:, None] * stride_dv_n + offs_d[None, :]
)
# 1. local singletons
# determine the range of query blocks that can attend to this KV tile
begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
end_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
else:
k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
else:
k = tl.load(
k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
)
v = tl.load(
v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
)
for start_m in range(begin_m, end_m, BLOCK_M):
start_m = tl.multiple_of(start_m, BLOCK_M)
# load q, do, and lse
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs + start_m * stride_qm
)
do = tl.load(
do_ptrs + start_m * stride_do_m
)
else:
q = tl.load(
q_ptrs + start_m * stride_qm,
mask=offs_d[None, :] < headdim,
other=0.0
)
do = tl.load(
do_ptrs + start_m * stride_do_m,
mask=offs_d[None, :] < headdim,
other=0.0
)
do_t_o = tl.load(
do_t_o_ptrs + start_m
)
lse = tl.load(
lse_ptrs + start_m
)
else:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs + start_m * stride_qm,
mask=(start_m + offs_m)[:, None] < seqlen_q,
other=0.0
)
do = tl.load(
do_ptrs + start_m * stride_do_m,
mask=(start_m + offs_m)[:, None] < seqlen_q,
other=0.0
)
else:
q = tl.load(
q_ptrs + start_m * stride_qm,
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0
)
do = tl.load(
do_ptrs + start_m * stride_do_m,
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0
)
do_t_o = tl.load(
do_t_o_ptrs + start_m,
mask=(start_m + offs_m)[:, None] < seqlen_q,
other=0.0
)
lse = tl.load(
lse_ptrs + start_m,
mask=(start_m + offs_m)[:, None] < seqlen_q,
other=0.0
)
lse = tl.where(lse == float("-inf"), 0.0, lse)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
if not EVEN_M:
qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))
if MASK_TYPE == 1:
if EVEN_M & EVEN_W:
mask = tl.load(
m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE)
)
else:
mask = tl.load(
m_ptrs + (start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE),
mask=((start_m + offs_m)[:, None] < seqlen_q)
& (((start_m * stride_window_mask_m) - (offs_w * WINDOW_SIZE) + offs_n)[None, :] < WINDOW_SIZE),
other=1,
)
# Slightly faster to multiply the softmax_scale in the tl.exp below, since the compiler
# can then fuse the mult and add into an fma instruction. But when a mask is applied,
# we need to multiply by softmax_scale here.
# we assume mask already implies the causal masking
qk = qk * softmax_scale
qk = tl.where(mask, float("-inf"), qk)
p = tl.exp(qk - lse)
else:
qk += tl.where((start_m + offs_m)[:, None] >= offs_n[None, :], 0, float("-inf"))
p = tl.exp(qk * softmax_scale - lse)
# dp [M, N]
dp = tl.dot(do, tl.trans(v))
# p [M, N], dp [M, N], do_t_o [M, 1] -> ds [M, N]
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
# p is fp32 and [M, N], convert to q.dtype
# do [M, D] -> dv [N, D]
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# dk [N, D]
dk += tl.dot(tl.trans(ds), q)
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
else:
tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
else:
tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _bwd_eva_agg_kernel_drfa_kv(
Q,
RFA_K,
RFA_V,
ChunkMask,
DO,
LSE,
DO_T_O,
D_RFA_K,
D_RFA_V,
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
stride_chunk_mask_b, stride_chunk_mask_m,
stride_do_b, stride_do_h, stride_do_m,
stride_lse_b, stride_lse_h,
stride_do_t_o_b, stride_do_t_o_h,
stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
nheads,
seqlen_q,
nchunks,
headdim,
CHUNKS_PER_WINDOW: tl.constexpr,
WINDOW_SIZE: tl.constexpr,
MASK_TYPE: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_C: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
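# Backward kernel for the RFA chunk-level K/V (D_RFA_K / D_RFA_V): each program owns
# one BLOCK_N tile of chunks and accumulates their gradients over all query blocks
# that come after the window these chunks summarize.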
off_bh = tl.program_id(1)
off_h = off_bh % nheads
off_b = off_bh // nheads
start_c = tl.program_id(0)
# each window is summarized by CHUNKS_PER_WINDOW chunks
offs_c = start_c * BLOCK_N + tl.arange(0, BLOCK_N)
# determine which window the current KV block belongs to
offs_w = (start_c * BLOCK_N) // CHUNKS_PER_WINDOW
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# initialize pointers
q_ptrs = (
Q +
off_b * stride_qb +
off_h * stride_qh +
(offs_m[:, None] * stride_qm + offs_d[None, :])
)
do_ptrs = (
DO +
off_b * stride_do_b +
off_h * stride_do_h +
(offs_m[:, None] * stride_do_m + offs_d[None, :])
)
do_t_o_ptrs = (
DO_T_O +
off_b * stride_do_t_o_b +
off_h * stride_do_t_o_h +
(offs_m[:, None])
)
lse_ptrs = (
LSE +
off_b * stride_lse_b +
off_h * stride_lse_h +
(offs_m[:, None])
)
rfa_k_ptrs = (
RFA_K +
off_b * stride_rfa_kb +
off_h * stride_rfa_kh +
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
)
rfa_v_ptrs = (
RFA_V +
off_b * stride_rfa_vb +
off_h * stride_rfa_vh +
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
)
if MASK_TYPE == 1:
rfa_m_ptrs = (
ChunkMask +
off_b * stride_chunk_mask_b +
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
)
d_rfa_k_ptrs = (
D_RFA_K +
off_b * stride_d_rfa_k_b +
off_h * stride_d_rfa_k_h +
(offs_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
)
d_rfa_v_ptrs = (
D_RFA_V +
off_b * stride_d_rfa_v_b +
off_h * stride_d_rfa_v_h +
(offs_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
)
d_rfa_k = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
d_rfa_v = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
if EVEN_C & EVEN_M:
if EVEN_HEADDIM:
rfa_k = tl.load(rfa_k_ptrs)
rfa_v = tl.load(rfa_v_ptrs)
else:
rfa_k = tl.load(rfa_k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
rfa_v = tl.load(rfa_v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
rfa_k = tl.load(rfa_k_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
rfa_v = tl.load(rfa_v_ptrs, mask=offs_c[:, None] < nchunks, other=0.0)
else:
rfa_k = tl.load(
rfa_k_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
)
rfa_v = tl.load(
rfa_v_ptrs, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim), other=0.0
)
begin_m = tl.minimum((offs_w + 1) * WINDOW_SIZE, seqlen_q)
end_m = seqlen_q
for start_m in range(begin_m, end_m, BLOCK_M):
start_m = tl.multiple_of(start_m, BLOCK_M)
# load q, do, and lse
if EVEN_M:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs + start_m * stride_qm
)
do = tl.load(
do_ptrs + start_m * stride_do_m
)
else:
q = tl.load(
q_ptrs + start_m * stride_qm,
mask=offs_d[None, :] < headdim,
other=0.0
)
do = tl.load(
do_ptrs + start_m * stride_do_m,
mask=offs_d[None, :] < headdim,
other=0.0
)
do_t_o = tl.load(
do_t_o_ptrs + start_m
)
lse = tl.load(
lse_ptrs + start_m
)
else:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs + start_m * stride_qm,
mask=(start_m + offs_m)[:, None] < seqlen_q,
other=0.0
)
do = tl.load(
do_ptrs + start_m * stride_do_m,
mask=(start_m + offs_m)[:, None] < seqlen_q,
other=0.0
)
else:
q = tl.load(
q_ptrs + start_m * stride_qm,
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0
)
do = tl.load(
do_ptrs + start_m * stride_do_m,
mask=((start_m + offs_m)[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0
)
do_t_o = tl.load(
do_t_o_ptrs + start_m,
mask=(start_m + offs_m)[:, None] < seqlen_q,
other=0.0
)
lse = tl.load(
lse_ptrs + start_m,
mask=(start_m + offs_m)[:, None] < seqlen_q,
other=0.0
)
lse = tl.where(lse == float("-inf"), 0.0, lse)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(rfa_k))
if not EVEN_M:
qk += tl.where((start_m + offs_m)[:, None] < seqlen_q, 0, float("-inf"))
if MASK_TYPE == 1:
if EVEN_M & EVEN_C:
mask = tl.load(
rfa_m_ptrs + (start_m * stride_chunk_mask_m)
)
else:
mask = tl.load(
rfa_m_ptrs + (start_m * stride_chunk_mask_m),
mask=((start_m + offs_m)[:, None] < seqlen_q)
& (offs_c[None, :] < nchunks),
other=1,
)
# Slightly faster to multiply the softmax_scale in the tl.exp below, since the compiler
# can then fuse the mult and add into an fma instruction. But when a mask is applied,
# we need to multiply by softmax_scale here.
# we assume mask already implies the causal masking
qk = qk * softmax_scale
qk = tl.where(mask, float("-inf"), qk)
p = tl.exp(qk - lse)
else:
p = tl.exp(qk * softmax_scale - lse)
dp = tl.dot(do, tl.trans(rfa_v))
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
# p is fp32, convert to q.dtype
d_rfa_v += tl.dot(tl.trans(p.to(do.dtype)), do)
# move softmax_scale to ds to save computation
d_rfa_k += tl.dot(tl.trans(ds), q)
if EVEN_C & EVEN_M:
if EVEN_HEADDIM:
tl.store(d_rfa_v_ptrs, d_rfa_v)
tl.store(d_rfa_k_ptrs, d_rfa_k)
else:
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_d[None, :] < headdim)
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=offs_c[:, None] < nchunks)
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=offs_c[:, None] < nchunks)
else:
tl.store(d_rfa_v_ptrs, d_rfa_v, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
tl.store(d_rfa_k_ptrs, d_rfa_k, mask=(offs_c[:, None] < nchunks) & (offs_d[None, :] < headdim))
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _bwd_eva_agg_kernel_dq(
Q,
K,
V,
RFA_K,
RFA_V,
WindowMask,
ChunkMask,
DO,
LSE,
DO_T_O,
DQ,
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
stride_window_mask_b, stride_window_mask_m,
stride_chunk_mask_b, stride_chunk_mask_m,
stride_do_b, stride_do_h, stride_do_m,
stride_lse_b, stride_lse_h,
stride_do_t_o_b, stride_do_t_o_h,
stride_dq_b, stride_dq_h, stride_dq_m,
nheads,
seqlen_q,
seqlen_k,
nchunks,
headdim,
CHUNKS_PER_WINDOW: tl.constexpr,
WINDOW_SIZE: tl.constexpr,
MASK_TYPE: tl.constexpr,
EMPTY_RFA_KV: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_W: tl.constexpr,
EVEN_C: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
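# Backward kernel for dQ: each program owns one BLOCK_M tile of queries and accumulates
# dQ from (1) the exact K/V tiles inside the query's own window and (2) the RFA
# chunk-level K/V of all earlier windows (only when EMPTY_RFA_KV == 0).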
start_m = tl.program_id(0)
off_bh = tl.program_id(1)
off_h = off_bh % nheads
off_b = off_bh // nheads
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
offs_n = tl.arange(0, BLOCK_N)
offs_c = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# TODO: add parentheses or not
q_ptrs = (
Q +
off_b * stride_qb +
off_h * stride_qh +
(offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K +
off_b * stride_kb +
off_h * stride_kh +
(offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V +
off_b * stride_vb +
off_h * stride_vh +
(offs_n[:, None] * stride_vn + offs_d[None, :])
)
if EMPTY_RFA_KV == 0:
rfa_k_ptrs = (
RFA_K +
off_b * stride_rfa_kb +
off_h * stride_rfa_kh +
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
)
rfa_v_ptrs = (
RFA_V +
off_b * stride_rfa_vb +
off_h * stride_rfa_vh +
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
)
dq_ptrs = (
DQ +
off_b * stride_dq_b +
off_h * stride_dq_h +
(offs_m[:, None] * stride_dq_m + offs_d[None, :])
)
do_ptrs = (
DO +
off_b * stride_do_b +
off_h * stride_do_h +
(offs_m[:, None] * stride_do_m + offs_d[None, :])
)
do_t_o_ptrs = (
DO_T_O +
off_b * stride_do_t_o_b +
off_h * stride_do_t_o_h +
offs_m[:, None]
)
lse_ptrs = (
LSE +
off_b * stride_lse_b +
off_h * stride_lse_h +
offs_m[:, None]
)
### load q, do, do_t_o, lse ####
if EVEN_M:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs
)
do = tl.load(
do_ptrs
)
else:
q = tl.load(
q_ptrs,
mask=offs_d[None, :] < headdim,
other=0.0
)
do = tl.load(
do_ptrs,
mask=offs_d[None, :] < headdim,
other=0.0
)
do_t_o = tl.load(
do_t_o_ptrs
)
lse = tl.load(
lse_ptrs
)
else:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs,
mask=offs_m[:, None] < seqlen_q,
other=0.0
)
do = tl.load(
do_ptrs,
mask=offs_m[:, None] < seqlen_q,
other=0.0
)
else:
q = tl.load(
q_ptrs,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0
)
do = tl.load(
do_ptrs,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0
)
do_t_o = tl.load(
do_t_o_ptrs,
mask=offs_m[:, None] < seqlen_q,
other=0.0
)
lse = tl.load(
lse_ptrs,
mask=offs_m[:, None] < seqlen_q,
other=0.0
)
lse = tl.where(lse == float("-inf"), 0.0, lse)
lse *= 1.4426950408889634 # log2(e)
qk_scale = softmax_scale
qk_scale *= 1.4426950408889634 # log2(e)
if MASK_TYPE == 1:
window_mask_ptrs = (
WindowMask +
off_b * stride_window_mask_b +
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
)
if EMPTY_RFA_KV == 0:
chunk_mask_ptrs = (
ChunkMask +
off_b * stride_chunk_mask_b +
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
)
dq = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# loop over k, v and update accumulator
# Iterate over local singletons:
# only the K/V blocks within the current window are visited
start_idx_n = offs_w * WINDOW_SIZE
end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
# Trying to combine the two masks seems to make the result wrong
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if MASK_TYPE == 1:
if EVEN_M & EVEN_W:
window_mask = tl.load(
window_mask_ptrs + start_n - start_idx_n
)
else:
window_mask = tl.load(
window_mask_ptrs + start_n - start_idx_n,
mask=(offs_m[:, None] < seqlen_q)
& ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
other=1,
)
# Slightly faster to multiply the softmax_scale in the tl.exp below, since the compiler
# can then fuse the mult and add into an fma instruction. But when a mask is applied,
# we need to multiply by softmax_scale here.
# we assume mask already implies the causal masking
qk = qk * qk_scale
qk = tl.where(window_mask, float("-inf"), qk)
p = tl.exp2(qk - lse)
else:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
p = tl.exp2(qk * qk_scale - lse)
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
dp = tl.dot(do, tl.trans(v))
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
dq += tl.dot(ds, k)
if EMPTY_RFA_KV == 0:
# Iterate over RFA chunks
# we only iterate over chunks before the current local singleton window
end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
for start_c in range(0, end_idx_c, BLOCK_N):
start_c = tl.multiple_of(start_c, BLOCK_N)
# -- compute qk ----
if EVEN_C & EVEN_M:
if EVEN_HEADDIM:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc
)
else:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=(start_c + offs_c)[:, None] < nchunks,
other=0.0,
)
else:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(rfa_k))
# Trying to combine the two masks seems to make the result wrong
if not EVEN_C: # Need to mask out otherwise the softmax is wrong
qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
if MASK_TYPE == 1:
if EVEN_C & EVEN_M:
chunk_mask = tl.load(
chunk_mask_ptrs + start_c
)
else:
chunk_mask = tl.load(
chunk_mask_ptrs + start_c,
mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
other=1,
)
# Slightly faster to multiply the softmax_scale in the tl.exp below, since the compiler
# can then fuse the mult and add into an fma instruction. But when a mask is applied,
# we need to multiply by softmax_scale here.
# we assume mask already implies the causal masking
qk = qk * qk_scale
qk = tl.where(chunk_mask, float("-inf"), qk)
p = tl.exp2(qk - lse)
else:
p = tl.exp2(qk * qk_scale - lse)
if EVEN_C & EVEN_M:
if EVEN_HEADDIM:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc
)
else:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
mask=(start_c + offs_c)[:, None] < nchunks,
other=0.0,
)
else:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0,
)
dp = tl.dot(do, tl.trans(rfa_v))
ds = (p * (dp - do_t_o) * softmax_scale).to(q.dtype)
dq += tl.dot(ds, rfa_k)
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
dq_ptrs = (
DQ +
off_b * stride_dq_b +
off_h * stride_dq_h +
(offs_m[:, None] * stride_dq_m + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(
dq_ptrs, dq
)
else:
tl.store(
dq_ptrs, dq,
mask=offs_d[None, :] < headdim
)
else:
if EVEN_HEADDIM:
tl.store(
dq_ptrs, dq,
mask=offs_m[:, None] < seqlen_q
)
else:
tl.store(
dq_ptrs, dq,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
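# Kernel launch configurations, keyed by (dtype, head_dim);
# each tuple is (BLOCK_M, BLOCK_N, num_warps, num_stages).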
_capability_90_config = {
"fwd": {
(torch.bfloat16, 64): (128, 128, 4, 3),
(torch.bfloat16, 128): (128, 128, 8, 3),
(torch.float32, 64): (128, 64, 8, 3),
(torch.float32, 128): (64, 32, 4, 3),
},
"bwd_dq": {
(torch.bfloat16, 64): (128, 64, 4, 3),
(torch.bfloat16, 128): (128, 64, 8, 3),
(torch.float32, 64): (128, 64, 8, 2),
(torch.float32, 128): (32, 32, 4, 2),
},
"bwd_dkdv": {
(torch.bfloat16, 64): (128, 64, 4, 2),
(torch.bfloat16, 128): (128, 64, 8, 2),
(torch.float32, 64): (128, 64, 8, 2),
(torch.float32, 128): (32, 32, 4, 1),
},
"bwd_drfa_kv": {
(torch.bfloat16, 64): (128, 64, 4, 2),
(torch.bfloat16, 128): (128, 64, 8, 2),
(torch.float32, 64): (128, 64, 8, 2),
(torch.float32, 128): (32, 32, 4, 1),
}
}
_capability_80_config = {
"fwd": {
(torch.bfloat16, 64): (64, 64, 4, 3),
(torch.bfloat16, 128): (64, 64, 8, 3),
(torch.float32, 64): (64, 32, 4, 2),
(torch.float32, 128): (64, 32, 8, 1),
},
"bwd_dq": {
(torch.bfloat16, 64): (64, 64, 4, 3),
(torch.bfloat16, 128): (64, 32, 4, 2),
(torch.float32, 64): (32, 32, 4, 2),
(torch.float32, 128): (32, 32, 4, 2),
},
"bwd_dkdv": {
(torch.bfloat16, 64): (64, 64, 4, 3),
(torch.bfloat16, 128): (32, 32, 4, 2),
(torch.float32, 64): (32, 32, 4, 1),
(torch.float32, 128): (16, 64, 8, 1),
},
"bwd_drfa_kv": {
(torch.bfloat16, 64): (64, 64, 4, 3),
(torch.bfloat16, 128): (64, 32, 4, 3),
(torch.float32, 64): (32, 32, 4, 1),
(torch.float32, 128): (32, 32, 4, 1),
}
}
def _get_config(dtype, head_dim, mode) -> tuple[int, int, int, int]:
capability = torch.cuda.get_device_capability()
if capability >= (9, 0):
kernel_config = _capability_90_config[mode].get((dtype, head_dim), (32, 32, 4, 1))
elif capability >= (8, 0):
kernel_config = _capability_80_config[mode].get((dtype, head_dim), (16, 16, 4, 1))
else:
if mode == "fwd":
if dtype == torch.float32:
kernel_config = (32, 16, 4, 2)
else:
kernel_config = (64, 32, 4, 2)
else:
if dtype == torch.float32:
kernel_config = (16, 16, 4, 1)
else:
kernel_config = (32, 32, 4, 1)
return kernel_config
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
"EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_eva_agg_kernel(
Q,
K,
V,
RFA_K,
RFA_V,
WindowMask,
ChunkMask,
Out,
LSE,
softmax_scale,
stride_qb, stride_qh, stride_qm,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
stride_window_mask_b, stride_window_mask_m,
stride_chunk_mask_b, stride_chunk_mask_m,
stride_ob, stride_oh, stride_om,
stride_lse_b, stride_lse_h,
nheads,
seqlen_q,
seqlen_k,
nchunks,
headdim,
CHUNKS_PER_WINDOW: tl.constexpr,
WINDOW_SIZE: tl.constexpr,
MASK_TYPE: tl.constexpr,
EMPTY_RFA_KV: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_W: tl.constexpr,
EVEN_C: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
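# Forward kernel: each program computes the output for one BLOCK_M tile of queries with
# an online softmax, first over the exact K/V inside the query's local window and then
# over the RFA chunk-level K/V that summarize earlier windows.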
start_m = tl.program_id(0)
off_bh = tl.program_id(1)
off_h = off_bh % nheads
off_b = off_bh // nheads
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
offs_n = tl.arange(0, BLOCK_N)
offs_c = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# TODO: add parentheses or not
q_ptrs = (
Q +
off_b * stride_qb +
off_h * stride_qh +
(offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K +
off_b * stride_kb +
off_h * stride_kh +
(offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V +
off_b * stride_vb +
off_h * stride_vh +
(offs_n[:, None] * stride_vn + offs_d[None, :])
)
if EMPTY_RFA_KV == 0:
rfa_k_ptrs = (
RFA_K +
off_b * stride_rfa_kb +
off_h * stride_rfa_kh +
(offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
)
rfa_v_ptrs = (
RFA_V +
off_b * stride_rfa_vb +
off_h * stride_rfa_vh +
(offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
)
qk_scale = softmax_scale
qk_scale *= 1.4426950408889634 # log2(e)
if MASK_TYPE == 1:
window_mask_ptrs = (
WindowMask +
off_b * stride_window_mask_b +
(offs_m[:, None] * stride_window_mask_m + offs_n[None, :])
)
if EMPTY_RFA_KV == 0:
chunk_mask_ptrs = (
ChunkMask +
off_b * stride_chunk_mask_b +
(offs_m[:, None] * stride_chunk_mask_m + offs_c[None, :])
)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
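# Online softmax state: m_i is the running max of the scaled logits (kept in log2 units
# since qk_scale folds in log2(e)), d_i is the running normalizer, and acc_o is the
# unnormalized output accumulator.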
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs
)
else:
q = tl.load(
q_ptrs,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
q = tl.load(
q_ptrs,
mask=offs_m[:, None] < seqlen_q,
other=0.0
)
else:
q = tl.load(
q_ptrs,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
other=0.0
)
# loop over k, v and update accumulator
# Iterate over local singletons:
# only the K/V blocks within the current window are visited
start_idx_n = offs_w * WINDOW_SIZE
end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
# Trying to combine the two masks seems to make the result wrong
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if MASK_TYPE == 1:
if EVEN_M & EVEN_W:
window_mask = tl.load(
window_mask_ptrs + start_n - start_idx_n
)
else:
window_mask = tl.load(
window_mask_ptrs + start_n - start_idx_n,
mask=(offs_m[:, None] < seqlen_q)
& ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
other=1,
)
# Slightly faster to multiply the softmax_scale in the tl.exp below, since the compiler
# can then fuse the mult and add into an fma instruction. But when a mask is applied,
# we need to multiply by softmax_scale here.
# we assume mask already implies the causal masking
qk = qk * qk_scale
qk = tl.where(window_mask, float("-inf"), qk)
m_ij = tl.maximum(tl.max(qk, 1), m_i)
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
p = tl.exp2(qk - m_ij_masked[:, None])
else:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])
d_ij = tl.sum(p, 1)
# scale acc_o
prev_scale = tl.exp2(m_i - m_ij_masked)
# # -- update output accumulator --
acc_o = acc_o * prev_scale[:, None]
# update acc_o
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
p = p.to(v.dtype)
acc_o = tl.dot(p, v, acc_o)
# -- update statistics
d_i = d_i * prev_scale + d_ij
m_i = m_ij
if EMPTY_RFA_KV == 0:
# Iterate over RFA chunks
# we only iterate over chunks before the current local singleton window
end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
for start_c in range(0, end_idx_c, BLOCK_N):
start_c = tl.multiple_of(start_c, BLOCK_N)
# -- compute qk ----
if EVEN_C & EVEN_M:
if EVEN_HEADDIM:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc
)
else:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=(start_c + offs_c)[:, None] < nchunks,
other=0.0,
)
else:
rfa_k = tl.load(
rfa_k_ptrs + start_c * stride_rfa_kc,
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0,
)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(rfa_k))
# Trying to combine the two masks seems to make the result wrong
if not EVEN_C: # Need to mask out otherwise the softmax is wrong
qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))
if MASK_TYPE == 1:
if EVEN_C & EVEN_M:
chunk_mask = tl.load(
chunk_mask_ptrs + start_c
)
else:
chunk_mask = tl.load(
chunk_mask_ptrs + start_c,
mask=(offs_m[:, None] < seqlen_q) & ((start_c + offs_c)[None, :] < nchunks),
other=1,
)
# Slightly faster to multiply the softmax_scale in the tl.exp below, since the compiler
# can then fuse the mult and add into an fma instruction. But when a mask is applied,
# we need to multiply by softmax_scale here.
# we assume mask already implies the causal masking
qk = qk * qk_scale
qk = tl.where(chunk_mask, float("-inf"), qk)
m_ij = tl.maximum(tl.max(qk, 1), m_i)
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
p = tl.exp2(qk - m_ij_masked[:, None])
else:
m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
p = tl.exp2(qk * qk_scale - m_ij_masked[:, None])
d_ij = tl.sum(p, 1)
# scale acc_o
prev_scale = tl.exp2(m_i - m_ij_masked)
# # -- update output accumulator --
acc_o = acc_o * prev_scale[:, None]
# update acc_o
# TODO: if we just do "if EVEN_C", there seems to be some race condition
if EVEN_C & EVEN_M:
if EVEN_HEADDIM:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc
)
else:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
mask=(start_c + offs_c)[:, None] < nchunks,
other=0.0,
)
else:
rfa_v = tl.load(
rfa_v_ptrs + start_c * stride_rfa_vc,
mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0,
)
p = p.to(rfa_v.dtype)
acc_o = tl.dot(p, rfa_v, acc_o)
# -- update statistics
d_i = d_i * prev_scale + d_ij
m_i = m_ij
# for rows that are all -inf, set d_i to 1.0
d_i = tl.where(d_i == 0.0, 1.0, d_i)
# convert from base-2 back to the natural log by multiplying by ln(2)
lse_m = (m_i + tl.math.log2(d_i)) * 0.6931471805599453
acc_o = acc_o / d_i[:, None]
# TODO: understand why rematerializing the offsets here saves registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_HEADDIM)
out_ptrs = (
Out +
off_b * stride_ob +
off_h * stride_oh +
(offs_m[:, None] * stride_om + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(
out_ptrs, acc_o
)
else:
tl.store(
out_ptrs, acc_o,
mask=offs_d[None, :] < headdim
)
else:
if EVEN_HEADDIM:
tl.store(
out_ptrs, acc_o,
mask=offs_m[:, None] < seqlen_q
)
else:
tl.store(
out_ptrs, acc_o,
mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)
lse_ptrs = (
LSE +
off_b * stride_lse_b +
off_h * stride_lse_h +
offs_m
)
if EVEN_M:
tl.store(
lse_ptrs, lse_m,
)
else:
tl.store(
lse_ptrs, lse_m,
mask=offs_m < seqlen_q
)
def triton_eva_agg_fwd(
q, k, v, rfa_k, rfa_v,
window_mask,
chunk_mask,
softmax_scale,
window_size,
chunks_per_window
):
if rfa_k is None and rfa_v is None:
empty_rfa_kv = 1
q, k, v = [
x if x.stride(-1) == 1 else x.contiguous()
for x in [q, k, v]
]
else:
assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
empty_rfa_kv = 0
q, k, v, rfa_k, rfa_v = [
x if x.stride(-1) == 1 else x.contiguous()
for x in [q, k, v, rfa_k, rfa_v]
]
# shape constraints
batch, nheads, seqlen_q, head_dim = q.shape
_, _, seqlen_k, _ = k.shape
if empty_rfa_kv == 0:
nchunks = rfa_k.shape[-2]
assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
else:
nchunks = 0
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
assert k.shape == (batch, nheads, seqlen_k, head_dim)
assert v.shape == (batch, nheads, seqlen_k, head_dim)
assert head_dim <= 128, "We only test head dimensions up to 128"
# assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
assert q.is_cuda and k.is_cuda and v.is_cuda
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
mask_type = 0
if window_mask is not None:
mask_type = 1
assert window_mask.dtype == torch.bool
assert window_mask.is_cuda
assert window_mask.dim() == 4
assert window_mask.shape == (batch, 1, seqlen_q, window_size)
if window_mask.stride(-1) != 1:
window_mask = window_mask.contiguous()
assert chunk_mask is not None
assert chunk_mask.dtype == torch.bool
assert chunk_mask.is_cuda
assert chunk_mask.dim() == 4
assert chunk_mask.shape == (batch, 1, seqlen_q, nchunks)
if chunk_mask.stride(-1) != 1:
chunk_mask = chunk_mask.contiguous()
chunk_mask_strides = (
(chunk_mask.stride(0), chunk_mask.stride(2))
if mask_type == 1 else
(0, 0)
)
window_mask_strides = (
(window_mask.stride(0), window_mask.stride(2))
if mask_type == 1 else
(0, 0)
)
rfa_k_strides = (
(rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
if empty_rfa_kv == 0 else
(0, 0, 0)
)
rfa_v_strides = (
(rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
if empty_rfa_kv == 0 else
(0, 0, 0)
)
o = torch.empty_like(q)
lse = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "fwd")
assert chunks_per_window >= BLOCK_N, "chunks_per_window must be at least BLOCK_N"
assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
_fwd_eva_agg_kernel[grid](
q,
k,
v,
rfa_k,
rfa_v,
window_mask,
chunk_mask,
o,
lse,
softmax_scale,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
window_mask_strides[0], window_mask_strides[1],
chunk_mask_strides[0], chunk_mask_strides[1],
o.stride(0), o.stride(1), o.stride(2),
lse.stride(0), lse.stride(1),
nheads,
seqlen_q,
seqlen_k,
nchunks,
head_dim,
chunks_per_window,
window_size,
mask_type,
empty_rfa_kv,
BLOCK_HEADDIM,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
)
return o, lse
def triton_eva_agg_bwd(
do,
q, k, v, rfa_k, rfa_v,
window_mask, chunk_mask,
o, lse,
dq, dk, dv, d_rfa_k, d_rfa_v,
softmax_scale,
window_size,
chunks_per_window,
empty_rfa_kv,
mask_type,
):
if do.stride(-1) != 1:
do = do.contiguous()
# shape constraints
batch, nheads, seqlen_q, head_dim = q.shape
_, _, seqlen_k, _ = k.shape
if empty_rfa_kv == 0:
nchunks = rfa_k.shape[-2]
assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
assert rfa_k.stride(-1) == rfa_v.stride(-1) == d_rfa_k.stride(-1) == d_rfa_v.stride(-1) == 1
assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
else:
nchunks = 0
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
assert lse.shape == (batch, nheads, seqlen_q)
assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1  # rfa_k/rfa_v strides are checked above only when present
assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
assert head_dim <= 128, "We only test head dimensions up to 128"
window_mask_strides = (
(window_mask.stride(0), window_mask.stride(2))
if mask_type == 1 else
(0, 0)
)
chunk_mask_strides = (
(chunk_mask.stride(0), chunk_mask.stride(2))
if mask_type == 1 else
(0, 0)
)
rfa_k_strides = (
(rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
if empty_rfa_kv == 0 else
(0, 0, 0)
)
rfa_v_strides = (
(rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
if empty_rfa_kv == 0 else
(0, 0, 0)
)
d_rfa_k_strides = (
(d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2))
if empty_rfa_kv == 0 else
(0, 0, 0)
)
d_rfa_v_strides = (
(d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2))
if empty_rfa_kv == 0 else
(0, 0, 0)
)
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
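# per-row correction term for the softmax backward: do_t_o[b, h, m] = sum_d dO * O,
# consumed in the kernels as ds = p * (dp - do_t_o) * softmax_scale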
do_t_o = torch.sum(do.to(torch.float32) * o.to(torch.float32), dim=-1).to(do.dtype)
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dq")
assert chunks_per_window >= BLOCK_N, "chunks_per_window must be at least BLOCK_N"
assert chunks_per_window % BLOCK_N == 0, "chunks_per_window must be a multiple of BLOCK_N"
grid = lambda META: (
triton.cdiv(seqlen_q, META["BLOCK_M"]),
batch * nheads,
)
_bwd_eva_agg_kernel_dq[grid](
q,
k,
v,
rfa_k,
rfa_v,
window_mask,
chunk_mask,
do,
lse,
do_t_o,
dq,
softmax_scale,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
window_mask_strides[0], window_mask_strides[1],
chunk_mask_strides[0], chunk_mask_strides[1],
do.stride(0), do.stride(1), do.stride(2),
lse.stride(0), lse.stride(1),
do_t_o.stride(0), do_t_o.stride(1),
dq.stride(0), dq.stride(1), dq.stride(2),
nheads,
seqlen_q,
seqlen_k,
nchunks,
head_dim,
chunks_per_window,
window_size,
mask_type,
empty_rfa_kv,
BLOCK_HEADDIM,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
)
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_dkdv")
grid = lambda META: (
triton.cdiv(seqlen_k, META["BLOCK_N"]),
batch * nheads,
)
_bwd_eva_agg_kernel_dkdv[grid](
q,
k,
v,
window_mask,
do,
lse,
do_t_o,
dk,
dv,
softmax_scale,
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
window_mask_strides[0], window_mask_strides[1],
do.stride(0), do.stride(1), do.stride(2),
lse.stride(0), lse.stride(1),
do_t_o.stride(0), do_t_o.stride(1),
dk.stride(0), dk.stride(1), dk.stride(2),
dv.stride(0), dv.stride(1), dv.stride(2),
nheads,
seqlen_q,
seqlen_k,
head_dim,
window_size,
mask_type,
BLOCK_HEADDIM,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
)
if empty_rfa_kv == 0:
BLOCK_M, BLOCK_N, num_warps, num_stages = _get_config(q.dtype, head_dim, "bwd_drfa_kv")
grid = lambda META: (
triton.cdiv(nchunks, META["BLOCK_N"]),
batch * nheads,
)
_bwd_eva_agg_kernel_drfa_kv[grid](
q,
rfa_k,
rfa_v,
chunk_mask,
do,
lse,
do_t_o,
d_rfa_k,
d_rfa_v,
softmax_scale,
q.stride(0), q.stride(1), q.stride(2),
rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
chunk_mask_strides[0], chunk_mask_strides[1],
do.stride(0), do.stride(1), do.stride(2),
lse.stride(0), lse.stride(1),
do_t_o.stride(0), do_t_o.stride(1),
d_rfa_k_strides[0], d_rfa_k_strides[1], d_rfa_k_strides[2],
d_rfa_v_strides[0], d_rfa_v_strides[1], d_rfa_v_strides[2],
nheads,
seqlen_q,
nchunks,
head_dim,
chunks_per_window,
window_size,
mask_type,
BLOCK_HEADDIM,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
num_stages=num_stages,
)
class EvaAggFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale=None, window_size=None, chunks_per_window=None):
if rfa_k is None and rfa_v is None:
empty_rfa_kv = 1
else:
assert rfa_k is not None and rfa_v is not None, "Both rfa_k and rfa_v must either be None or have values at the same time."
empty_rfa_kv = 0
if window_mask is not None:
mask_type = 1
else:
mask_type = 0
o, lse = triton_eva_agg_fwd(
q, k, v, rfa_k, rfa_v, window_mask, chunk_mask, softmax_scale, window_size, chunks_per_window
)
ctx.save_for_backward(q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask)
ctx.softmax_scale = softmax_scale
ctx.window_size = window_size
ctx.chunks_per_window = chunks_per_window
ctx.empty_rfa_kv = empty_rfa_kv
ctx.mask_type = mask_type
return o
@staticmethod
def backward(ctx, do):
q, k, v, o, lse, rfa_k, rfa_v, window_mask, chunk_mask = ctx.saved_tensors
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
if ctx.empty_rfa_kv == 0:
d_rfa_k = torch.empty_like(rfa_k)
d_rfa_v = torch.empty_like(rfa_v)
else:
d_rfa_k = None
d_rfa_v = None
triton_eva_agg_bwd(
do,
q,
k,
v,
rfa_k,
rfa_v,
window_mask,
chunk_mask,
o,
lse,
dq,
dk,
dv,
d_rfa_k,
d_rfa_v,
softmax_scale=ctx.softmax_scale,
window_size=ctx.window_size,
chunks_per_window=ctx.chunks_per_window,
empty_rfa_kv=ctx.empty_rfa_kv,
mask_type=ctx.mask_type,
)
return dq, dk, dv, d_rfa_k, d_rfa_v, None, None, None, None, None
def eva_agg_func_triton(
q, k, v, rfa_k, rfa_v,
window_mask, chunk_mask,
softmax_scale=None, window_size=None, chunks_per_window=None,
):
return EvaAggFunc.apply(
q, k, v, rfa_k, rfa_v,
window_mask, chunk_mask,
softmax_scale, window_size, chunks_per_window,
)
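# A minimal smoke-test sketch (not part of the original kernels). It assumes a CUDA
# device with bf16 support; the shapes, window_size, and chunks_per_window below are
# illustrative values chosen to satisfy the block-size assertions in triton_eva_agg_fwd.
if __name__ == "__main__":
    batch, nheads, seqlen, head_dim = 2, 4, 512, 64
    q, k, v = [
        torch.randn(batch, nheads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16)
        for _ in range(3)
    ]
    # no RFA chunks and no explicit masks: this exercises the windowed causal path only
    out = eva_agg_func_triton(
        q, k, v,
        rfa_k=None, rfa_v=None,
        window_mask=None, chunk_mask=None,
        softmax_scale=None,
        window_size=256,        # a multiple of the largest BLOCK_M / BLOCK_N configured above
        chunks_per_window=128,  # a multiple of every BLOCK_N configured above
    )
    print(out.shape)  # expected: torch.Size([2, 4, 512, 64])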