# EvaByte-SFT / eva_prep_kv_kernel.py
import math
import torch
import triton
import triton.language as tl
@triton.heuristics(
{
"EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_eva_prep_kv_kernel(
K, # [b, h, n, d]
V, # [b, h, n, d]
PARAM_MU, # [1, h, 1, 1, d]
PARAM_PHI, # [1, h, 1, 1, d]
    Mask, # [b, 1, n, 1]
Out_RFA_K, # [b, h, c, d]
Out_RFA_V, # [b, h, c, d]
softmax_scale,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_mu_h,
stride_phi_h,
stride_mb, stride_mn,
stride_ok_b, stride_ok_h, stride_ok_c,
stride_ov_b, stride_ov_h, stride_ov_c,
nheads,
seqlen,
nchunks,
headdim,
CHUNKS_PER_BLOCK: tl.constexpr,
CHUNK_SIZE: tl.constexpr,
MASK_TYPE: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_n = tl.program_id(0)
offs_bh = tl.program_id(1)
offs_h = offs_bh % nheads
offs_b = offs_bh // nheads
# initialize offsets
    # we load BLOCK_N keys and values at a time, and
    # reshape them to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
offs_m = tl.arange(0, CHUNK_SIZE)
offs_d = tl.arange(0, BLOCK_HEADDIM)
k_ptrs = (
K +
offs_b * stride_kb +
offs_h * stride_kh +
(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) * stride_kn +
offs_d[None, None, :]
)
)
v_ptrs = (
V +
offs_b * stride_vb +
offs_h * stride_vh +
(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) * stride_vn +
offs_d[None, None, :]
)
)
param_mu_ptrs = (
PARAM_MU +
offs_h * stride_mu_h +
offs_d[None, None, :]
)
param_phi_ptrs = (
PARAM_PHI +
offs_h * stride_phi_h +
offs_d[None, None, :]
)
log2e = 1.4426950408889634
if MASK_TYPE == 1:
m_ptrs = (
Mask +
offs_b * stride_mb +
(
(
start_n * BLOCK_N +
offs_c[:, None] * CHUNK_SIZE +
offs_m[None, :]
) * stride_mn
)
)
if EVEN_N:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs
)
else:
k = tl.load(
k_ptrs,
mask=offs_d[None, None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs,
mask=(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen,
other=0.0
)
else:
k = tl.load(
k_ptrs,
mask=(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen
) & (offs_d[None, None, :] < headdim),
other=0.0
)
param_mu = tl.load(param_mu_ptrs).to(k.dtype)
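    # RFA-K chunk weights: softmax over each chunk of the logits <k, param_mu>
    # (log2e converts to the base-2 exponential used by tl.exp2 below)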
rfa_k_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
rfa_k_c_w += tl.sum(k * param_mu, axis=-1)
rfa_k_c_w *= log2e
if MASK_TYPE == 1:
if EVEN_N:
mask = tl.load(
m_ptrs
)
else:
mask = tl.load(
m_ptrs,
mask=(
start_n * BLOCK_N +
offs_c[:, None] * CHUNK_SIZE +
offs_m[None, :]
) < seqlen,
other=1,
)
rfa_k_c_w = tl.where(mask, float("-inf"), rfa_k_c_w)
m_rfa_k_c_w = tl.max(rfa_k_c_w, axis=-1)
masked_out_rows_rfa_k = (m_rfa_k_c_w == float("-inf"))
m_rfa_k_c_w_masked = tl.where(masked_out_rows_rfa_k, 0, m_rfa_k_c_w)
rfa_k_c_w = tl.exp2(rfa_k_c_w - m_rfa_k_c_w_masked[:, None])
denom_k = tl.sum(rfa_k_c_w, axis=-1)
denom_k = tl.where(denom_k == 0.0, 1.0, denom_k)
rfa_k_c_w = rfa_k_c_w / denom_k[:, None]
rfa_k_c = tl.sum(k * rfa_k_c_w[:, :, None].to(k.dtype), axis=-2)
    # TODO: understand why rematerializing the offsets here saves registers
offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
out_rfa_k_ptrs = (
Out_RFA_K +
offs_b * stride_ok_b +
offs_h * stride_ok_h +
(offs_out_c[:, None] * stride_ok_c + offs_d[None, :])
)
if EVEN_N:
if EVEN_HEADDIM:
tl.store(
out_rfa_k_ptrs, rfa_k_c
)
else:
tl.store(
out_rfa_k_ptrs, rfa_k_c,
mask=offs_d[None, :] < headdim
)
else:
if EVEN_HEADDIM:
tl.store(
out_rfa_k_ptrs, rfa_k_c,
mask=offs_out_c[:, None] < nchunks
)
else:
tl.store(
out_rfa_k_ptrs, rfa_k_c,
mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
)
param_phi = tl.load(param_phi_ptrs).to(k.dtype)
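    # RFA-V chunk weights: softmax over each chunk of
    # softmax_scale * (<k, param_phi> - 0.5 * ||k||^2)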
rfa_v_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
rfa_v_c_w += tl.sum(k * param_phi, axis=-1)
rfa_v_c_w -= (0.5 * tl.sum(k * k, axis=-1))
rfa_v_c_w *= log2e * softmax_scale
    if not EVEN_N: # Need to mask out-of-range positions, otherwise the softmax is wrong
rfa_v_c_w += tl.where(
(
start_n * BLOCK_N +
offs_c[:, None] * CHUNK_SIZE +
offs_m[None, :]
) < seqlen,
0,
float("-inf")
)
if MASK_TYPE == 1:
rfa_v_c_w = tl.where(mask, float("-inf"), rfa_v_c_w)
if EVEN_N:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs
)
else:
v = tl.load(
v_ptrs,
mask=offs_d[None, None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs,
mask=(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen,
other=0.0
)
else:
v = tl.load(
v_ptrs,
mask=(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen
) & (offs_d[None, None, :] < headdim),
other=0.0
)
m_rfa_v_c_w = tl.max(rfa_v_c_w, axis=-1)
masked_out_rows_rfa_v = (m_rfa_v_c_w == float("-inf"))
m_rfa_v_c_w_masked = tl.where(masked_out_rows_rfa_v, 0, m_rfa_v_c_w)
rfa_v_c_w = tl.exp2(rfa_v_c_w - m_rfa_v_c_w_masked[:, None])
denom_v = tl.sum(rfa_v_c_w, axis=-1)
denom_v = tl.where(denom_v == 0.0, 1.0, denom_v)
rfa_v_c_w = rfa_v_c_w / denom_v[:, None]
rfa_v_c = tl.sum(v * rfa_v_c_w[:, :, None].to(v.dtype), axis=-2)
offs_out_c = start_n * CHUNKS_PER_BLOCK + tl.arange(0, CHUNKS_PER_BLOCK)
out_rfa_v_ptrs = (
Out_RFA_V +
offs_b * stride_ov_b +
offs_h * stride_ov_h +
(offs_out_c[:, None] * stride_ov_c + offs_d[None, :])
)
if EVEN_N:
if EVEN_HEADDIM:
tl.store(
out_rfa_v_ptrs, rfa_v_c
)
else:
tl.store(
out_rfa_v_ptrs, rfa_v_c,
mask=offs_d[None, :] < headdim
)
else:
if EVEN_HEADDIM:
tl.store(
out_rfa_v_ptrs, rfa_v_c,
mask=offs_out_c[:, None] < nchunks
)
else:
tl.store(
out_rfa_v_ptrs, rfa_v_c,
mask=(offs_out_c[:, None] < nchunks) & (offs_d[None, :] < headdim)
)
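
# Pure-PyTorch reference of the forward kernel above. This is only a hedged
# sanity-check sketch: it mirrors the chunk-level softmax aggregation of K and V
# (mask == True means "masked out", matching the kernel), but it uses plain torch
# ops and does not reproduce the kernel's exact numerics. Fully masked chunks are
# zeroed to match the kernel's denominator handling.
def _reference_eva_prep_kv(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
    batch, nheads, seqlen, head_dim = k.shape
    nchunks = seqlen // chunksize
    k_c = k.reshape(batch, nheads, nchunks, chunksize, head_dim)
    v_c = v.reshape(batch, nheads, nchunks, chunksize, head_dim)
    # logits for the K aggregation: <k, mu>
    w_k = (k_c * param_mu).sum(-1)
    # logits for the V aggregation: scale * (<k, phi> - 0.5 * ||k||^2)
    w_v = softmax_scale * ((k_c * param_phi).sum(-1) - 0.5 * (k_c * k_c).sum(-1))
    if mask is not None:
        m = mask.reshape(batch, 1, nchunks, chunksize)
        w_k = w_k.masked_fill(m, float("-inf"))
        w_v = w_v.masked_fill(m, float("-inf"))
    w_k = torch.softmax(w_k.float(), dim=-1).nan_to_num(0.0).to(k.dtype)
    w_v = torch.softmax(w_v.float(), dim=-1).nan_to_num(0.0).to(v.dtype)
    rfa_k = (w_k.unsqueeze(-1) * k_c).sum(-2)
    rfa_v = (w_v.unsqueeze(-1) * v_c).sum(-2)
    return rfa_k, rfa_v

# The backward kernel below recomputes these chunk softmaxes from K, produces
# d_k and d_v, and writes per-(batch, head, seqlen-block) partial gradients for
# param_mu and param_phi that the Python wrapper later reduces to [1, h, 1, 1, d].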
@triton.heuristics(
{
"EVEN_N": lambda args: args["seqlen"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _bwd_eva_prep_kv_kernel(
RFA_K, # [b, h, c, d]
RFA_V, # [b, h, c, d]
K, # [b, h, n, d]
V, # [b, h, n, d]
PARAM_MU, # [1, h, 1, 1, d]
PARAM_PHI, # [1, h, 1, 1, d]
    Mask, # [b, 1, n, 1]
D_RFA_K, # [b, h, c, d]
D_RFA_V, # [b, h, c, d]
D_K, # [b, h, n, d]
D_V, # [b, h, n, d]
D_PARAM_MU_PARTIAL, # [b, h, g, d]
D_PARAM_PHI_PARTIAL, # [b, h, g, d]
softmax_scale,
stride_rfa_k_b, stride_rfa_k_h, stride_rfa_k_c,
stride_rfa_v_b, stride_rfa_v_h, stride_rfa_v_c,
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_mu_h,
stride_phi_h,
stride_mb, stride_mn,
stride_d_rfa_k_b, stride_d_rfa_k_h, stride_d_rfa_k_c,
stride_d_rfa_v_b, stride_d_rfa_v_h, stride_d_rfa_v_c,
stride_d_k_b, stride_d_k_h, stride_d_k_n,
stride_d_v_b, stride_d_v_h, stride_d_v_n,
stride_d_mu_b, stride_d_mu_h, stride_d_mu_g,
stride_d_phi_b, stride_d_phi_h, stride_d_phi_g,
nheads,
seqlen,
nchunks,
headdim,
CHUNKS_PER_BLOCK: tl.constexpr,
CHUNK_SIZE: tl.constexpr,
MASK_TYPE: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_n = tl.program_id(0)
offs_bh = tl.program_id(1)
offs_h = offs_bh % nheads
offs_b = offs_bh // nheads
# initialize offsets
    # we load BLOCK_N keys and values at a time, and
    # reshape them to [CHUNKS_PER_BLOCK, CHUNK_SIZE]
offs_c = tl.arange(0, CHUNKS_PER_BLOCK)
offs_m = tl.arange(0, CHUNK_SIZE)
offs_d = tl.arange(0, BLOCK_HEADDIM)
offs_rfa_c = start_n * CHUNKS_PER_BLOCK + offs_c
k_ptrs = (
K +
offs_b * stride_kb +
offs_h * stride_kh +
(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) * stride_kn +
offs_d[None, None, :]
)
)
rfa_k_ptrs = (
RFA_K +
offs_b * stride_rfa_k_b +
offs_h * stride_rfa_k_h +
(offs_rfa_c[:, None] * stride_rfa_k_c + offs_d[None, :])
)
rfa_v_ptrs = (
RFA_V +
offs_b * stride_rfa_v_b +
offs_h * stride_rfa_v_h +
(offs_rfa_c[:, None] * stride_rfa_v_c + offs_d[None, :])
)
d_rfa_k_ptrs = (
D_RFA_K +
offs_b * stride_d_rfa_k_b +
offs_h * stride_d_rfa_k_h +
(offs_rfa_c[:, None] * stride_d_rfa_k_c + offs_d[None, :])
)
d_rfa_v_ptrs = (
D_RFA_V +
offs_b * stride_d_rfa_v_b +
offs_h * stride_d_rfa_v_h +
(offs_rfa_c[:, None] * stride_d_rfa_v_c + offs_d[None, :])
)
param_mu_ptrs = (
PARAM_MU +
offs_h * stride_mu_h +
offs_d[None, None, :]
)
param_phi_ptrs = (
PARAM_PHI +
offs_h * stride_phi_h +
offs_d[None, None, :]
)
log2e = 1.4426950408889634
if MASK_TYPE == 1:
m_ptrs = (
Mask +
offs_b * stride_mb +
(
(
start_n * BLOCK_N +
offs_c[:, None] * CHUNK_SIZE +
offs_m[None, :]
) * stride_mn
)
)
if EVEN_N:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs
)
else:
k = tl.load(
k_ptrs,
mask=offs_d[None, None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs,
mask=(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen,
other=0.0
)
else:
k = tl.load(
k_ptrs,
mask=(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen
) & (offs_d[None, None, :] < headdim),
other=0.0
)
if EVEN_N:
if EVEN_HEADDIM:
rfa_k = tl.load(
rfa_k_ptrs
)
else:
rfa_k = tl.load(
rfa_k_ptrs,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
rfa_k = tl.load(
rfa_k_ptrs,
mask=offs_rfa_c[:, None] < nchunks,
other=0.0
)
else:
rfa_k = tl.load(
rfa_k_ptrs,
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0
)
if EVEN_N:
if EVEN_HEADDIM:
d_rfa_k = tl.load(
d_rfa_k_ptrs
)
else:
d_rfa_k = tl.load(
d_rfa_k_ptrs,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
d_rfa_k = tl.load(
d_rfa_k_ptrs,
mask=offs_rfa_c[:, None] < nchunks,
other=0.0
)
else:
d_rfa_k = tl.load(
d_rfa_k_ptrs,
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0
)
param_mu = tl.load(param_mu_ptrs).to(k.dtype)
mu_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
mu_c_w += tl.sum(k * param_mu, axis=-1)
mu_c_w *= log2e
    if not EVEN_N: # Need to mask out-of-range positions, otherwise the softmax is wrong
mu_c_w += tl.where(
(
start_n * BLOCK_N +
offs_c[:, None] * CHUNK_SIZE +
offs_m[None, :]
) < seqlen,
0,
float("-inf")
)
if MASK_TYPE == 1:
if EVEN_N:
mask = tl.load(
m_ptrs
)
else:
mask = tl.load(
m_ptrs,
mask=(
start_n * BLOCK_N +
offs_c[:, None] * CHUNK_SIZE +
offs_m[None, :]
) < seqlen,
other=1,
)
mu_c_w = tl.where(mask, float("-inf"), mu_c_w)
# [c, w]
m_mu_c_w = tl.max(mu_c_w, axis=-1)
masked_out_rows_mu = (m_mu_c_w == float("-inf"))
m_mu_c_w_masked = tl.where(masked_out_rows_mu, 0, m_mu_c_w)
mu_c_w = tl.exp2(mu_c_w - m_mu_c_w_masked[:, None])
denom_mu = tl.sum(mu_c_w, axis=-1)
denom_mu = tl.where(denom_mu == 0.0, 1.0, denom_mu)
mu_tilde_c_w = mu_c_w / denom_mu[:, None]
mu_tilde_c_w = mu_tilde_c_w.to(k.dtype)
# [c, d] [c, w, d] -> [c, w]
d_mu_tilde_c_w = tl.sum(d_rfa_k[:, None, :] * k, axis=-1)
# [c, d] [c, d] -> [c]
d_out_rfa_k_t_rfa_k = tl.sum(d_rfa_k * rfa_k, axis=-1)[:, None]
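    # softmax backward: d_logits = (d_weights - <d_weights, weights>) * weights,
    # and here <d_weights, weights> = <d_rfa_k, rfa_k> because rfa_k = sum_w weights[w] * k[w]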
d_mu_c_w = (d_mu_tilde_c_w - d_out_rfa_k_t_rfa_k) * mu_tilde_c_w
# [c, w] [c, w, d] -> [d]
d_param_mu = tl.sum(tl.sum(d_mu_c_w[:, :, None] * k, axis=0), axis=0)
# [c, w] [c, d] + [c, w] [1, 1, d] -> [c, w, d]
d_k = mu_tilde_c_w[:, :, None] * d_rfa_k[:, None, :] + d_mu_c_w[:, :, None] * param_mu
d_param_mu_partial_ptrs = (
D_PARAM_MU_PARTIAL +
offs_b * stride_d_mu_b +
offs_h * stride_d_mu_h +
start_n * stride_d_mu_g +
offs_d
)
if EVEN_HEADDIM:
tl.store(
d_param_mu_partial_ptrs, d_param_mu
)
else:
tl.store(
d_param_mu_partial_ptrs, d_param_mu,
mask=offs_d < headdim
)
v_ptrs = (
V +
offs_b * stride_vb +
offs_h * stride_vh +
(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) * stride_vn +
offs_d[None, None, :]
)
)
if EVEN_N:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs
)
else:
v = tl.load(
v_ptrs,
mask=offs_d[None, None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs,
mask=(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen,
other=0.0
)
else:
v = tl.load(
v_ptrs,
mask=(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen
) & (offs_d[None, None, :] < headdim),
other=0.0
)
if EVEN_N:
if EVEN_HEADDIM:
rfa_v = tl.load(
rfa_v_ptrs
)
else:
rfa_v = tl.load(
rfa_v_ptrs,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
rfa_v = tl.load(
rfa_v_ptrs,
mask=offs_rfa_c[:, None] < nchunks,
other=0.0
)
else:
rfa_v = tl.load(
rfa_v_ptrs,
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0
)
if EVEN_N:
if EVEN_HEADDIM:
d_rfa_v = tl.load(
d_rfa_v_ptrs
)
else:
d_rfa_v = tl.load(
d_rfa_v_ptrs,
mask=offs_d[None, :] < headdim,
other=0.0
)
else:
if EVEN_HEADDIM:
d_rfa_v = tl.load(
d_rfa_v_ptrs,
mask=offs_rfa_c[:, None] < nchunks,
other=0.0
)
else:
d_rfa_v = tl.load(
d_rfa_v_ptrs,
mask=(offs_rfa_c[:, None] < nchunks) & (offs_d[None, :] < headdim),
other=0.0
)
param_phi = tl.load(param_phi_ptrs).to(k.dtype)
phi_c_w = tl.zeros([CHUNKS_PER_BLOCK, CHUNK_SIZE], dtype=tl.float32)
phi_c_w += tl.sum(k * param_phi, axis=-1)
phi_c_w -= (0.5 * tl.sum(k * k, axis=-1))
phi_c_w *= log2e * softmax_scale
    if not EVEN_N: # Need to mask out-of-range positions, otherwise the softmax is wrong
phi_c_w += tl.where(
(
start_n * BLOCK_N +
offs_c[:, None] * CHUNK_SIZE +
offs_m[None, :]
) < seqlen,
0,
float("-inf")
)
if MASK_TYPE == 1:
phi_c_w = tl.where(mask, float("-inf"), phi_c_w)
m_phi_c_w = tl.max(phi_c_w, axis=-1)
masked_out_rows_phi = (m_phi_c_w == float("-inf"))
m_phi_c_w_masked = tl.where(masked_out_rows_phi, 0, m_phi_c_w)
phi_c_w = tl.exp2(phi_c_w - m_phi_c_w_masked[:, None])
denom_phi = tl.sum(phi_c_w, axis=-1)
denom_phi = tl.where(denom_phi == 0.0, 1.0, denom_phi)
phi_tilde_c_w = phi_c_w / denom_phi[:, None]
# phi_c_w = tl.exp2(phi_c_w - tl.max(phi_c_w, axis=-1)[:, None])
# phi_tilde_c_w = phi_c_w / tl.sum(phi_c_w, axis=-1)[:, None]
phi_tilde_c_w = phi_tilde_c_w.to(k.dtype)
d_phi_tilde_c_w = tl.sum(d_rfa_v[:, None, :] * v, axis=-1)
d_out_rfa_v_t_rfa_v = tl.sum(d_rfa_v * rfa_v, axis=-1)[:, None]
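    # same softmax-backward identity as in the mu branch above, applied to the V weights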
d_phi_c_w = (d_phi_tilde_c_w.to(tl.float32) - d_out_rfa_v_t_rfa_v.to(tl.float32)) * phi_tilde_c_w
d_param_phi = tl.sum(tl.sum(d_phi_c_w[:, :, None] * k * softmax_scale, axis=0), axis=0)
d_v = phi_tilde_c_w[:, :, None] * d_rfa_v[:, None, :]
# [c, w, d] + [c, w] * [1, 1, d] - [c, w, d]
d_k = d_k + softmax_scale * d_phi_c_w[:, :, None] * (param_phi - k)
d_k_ptrs = (
D_K +
offs_b * stride_d_k_b +
offs_h * stride_d_k_h +
(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) * stride_d_k_n +
offs_d[None, None, :]
)
)
d_v_ptrs = (
D_V +
offs_b * stride_d_v_b +
offs_h * stride_d_v_h +
(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) * stride_d_v_n +
offs_d[None, None, :]
)
)
if EVEN_N:
if EVEN_HEADDIM:
tl.store(
d_k_ptrs, d_k
)
tl.store(
d_v_ptrs, d_v
)
else:
tl.store(
d_k_ptrs, d_k,
mask=offs_d[None, None, :] < headdim
)
tl.store(
d_v_ptrs, d_v,
mask=offs_d[None, None, :] < headdim
)
else:
if EVEN_HEADDIM:
tl.store(
d_k_ptrs, d_k,
mask=(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen
),
)
tl.store(
d_v_ptrs, d_v,
mask=(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen
),
)
else:
tl.store(
d_k_ptrs, d_k,
mask=(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen
) & (offs_d[None, None, :] < headdim),
)
tl.store(
d_v_ptrs, d_v,
mask=(
(
start_n * BLOCK_N +
offs_c[:, None, None] * CHUNK_SIZE +
offs_m[None, :, None]
) < seqlen
) & (offs_d[None, None, :] < headdim),
)
d_param_phi_partial_ptrs = (
D_PARAM_PHI_PARTIAL +
offs_b * stride_d_phi_b +
offs_h * stride_d_phi_h +
start_n * stride_d_phi_g +
offs_d
)
if EVEN_HEADDIM:
tl.store(
d_param_phi_partial_ptrs, d_param_phi
)
else:
tl.store(
d_param_phi_partial_ptrs, d_param_phi,
mask=offs_d < headdim
)
def triton_eva_prep_kv_fwd(k, v, param_mu, param_phi, mask, softmax_scale, chunksize):
k, v, param_mu, param_phi = [
x if x.stride(-1) == 1 else x.contiguous()
for x in [k, v, param_mu, param_phi]
]
# shape constraints
batch, nheads, seqlen, head_dim = k.shape
assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
nchunks = seqlen // chunksize
assert k.shape == (batch, nheads, seqlen, head_dim)
assert v.shape == (batch, nheads, seqlen, head_dim)
assert param_mu.shape == (1, nheads, 1, 1, head_dim)
assert param_phi.shape == (1, nheads, 1, 1, head_dim)
assert head_dim <= 128, "We only test head dimensions up to 128"
    assert k.dtype == v.dtype == param_mu.dtype == param_phi.dtype, "All tensors must have the same dtype"
assert k.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
assert k.is_cuda and v.is_cuda
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
mask_type = 0
if mask is not None:
mask_type = 1
assert mask.dtype == torch.bool
assert mask.is_cuda
assert mask.dim() == 4
assert mask.shape == (batch, 1, seqlen, 1)
if mask.stride(-1) != 1:
mask = mask.contiguous()
mask_strides = (
(mask.stride(0), mask.stride(2))
if mask_type == 1 else
(0, 0)
)
out_rfa_k = torch.empty((batch, nheads, nchunks, head_dim), dtype=k.dtype, device=k.device)
out_rfa_v = torch.empty((batch, nheads, nchunks, head_dim), dtype=v.dtype, device=v.device)
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
BLOCK = 128
num_warps = 4 if head_dim <= 64 else 8
    assert BLOCK > chunksize and BLOCK % chunksize == 0, "BLOCK must be larger than and divisible by chunksize"
chunks_per_block = BLOCK // chunksize
grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_N"]), batch * nheads)
_fwd_eva_prep_kv_kernel[grid](
k,
v,
param_mu,
param_phi,
mask,
out_rfa_k,
out_rfa_v,
softmax_scale,
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
param_mu.stride(1),
param_phi.stride(1),
mask_strides[0], mask_strides[1],
out_rfa_k.stride(0), out_rfa_k.stride(1), out_rfa_k.stride(2),
out_rfa_v.stride(0), out_rfa_v.stride(1), out_rfa_v.stride(2),
nheads,
seqlen,
nchunks,
head_dim,
chunks_per_block,
chunksize,
mask_type,
BLOCK_HEADDIM,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return out_rfa_k, out_rfa_v
def triton_eva_prep_kv_bwd(
d_rfa_k, d_rfa_v,
k, v, param_mu, param_phi,
mask,
rfa_k, rfa_v,
d_k, d_v, d_param_mu, d_param_phi,
softmax_scale,
mask_type,
chunksize
):
d_rfa_k, d_rfa_v = [
x if x.stride(-1) == 1 else x.contiguous()
for x in [d_rfa_k, d_rfa_v]
]
# shape constraints
batch, nheads, seqlen, head_dim = k.shape
assert seqlen % chunksize == 0, "seqlen must be divisible by chunksize"
nchunks = seqlen // chunksize
softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)
mask_strides = (
(mask.stride(0), mask.stride(2))
if mask_type == 1 else
(0, 0)
)
BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
BLOCK = 128
num_warps = 4 if head_dim <= 64 else 8
    assert BLOCK > chunksize and BLOCK % chunksize == 0, "BLOCK must be larger than and divisible by chunksize"
chunks_per_block = BLOCK // chunksize
partial_groups = triton.cdiv(seqlen, BLOCK)
d_param_mu_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
d_param_phi_partial = torch.zeros((batch, nheads, partial_groups, head_dim), dtype=torch.float32, device=d_rfa_k.device)
grid = lambda META: (partial_groups, batch * nheads)
_bwd_eva_prep_kv_kernel[grid](
rfa_k, # [b, h, c, d]
rfa_v, # [b, h, c, d]
k, # [b, h, n, d]
v, # [b, h, n, d]
param_mu, # [1, h, 1, 1, d]
param_phi, # [1, h, 1, 1, d]
        mask, # [b, 1, n, 1]
d_rfa_k, # [b, h, c, d]
d_rfa_v, # [b, h, c, d]
d_k, # [b, h, n, d]
d_v, # [b, h, n, d]
d_param_mu_partial, # [b, h, g, d]
d_param_phi_partial, # [b, h, g, d]
softmax_scale,
rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2),
rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
param_mu.stride(1),
param_phi.stride(1),
mask_strides[0], mask_strides[1],
d_rfa_k.stride(0), d_rfa_k.stride(1), d_rfa_k.stride(2),
d_rfa_v.stride(0), d_rfa_v.stride(1), d_rfa_v.stride(2),
d_k.stride(0), d_k.stride(1), d_k.stride(2),
d_v.stride(0), d_v.stride(1), d_v.stride(2),
d_param_mu_partial.stride(0), d_param_mu_partial.stride(1), d_param_mu_partial.stride(2),
d_param_phi_partial.stride(0), d_param_phi_partial.stride(1), d_param_phi_partial.stride(2),
nheads,
seqlen,
nchunks,
head_dim,
chunks_per_block,
chunksize,
mask_type,
BLOCK_HEADDIM,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
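    # reduce the per-(batch, head, seqlen-block) partials [b, h, g, d] into the [1, h, 1, 1, d] parameter-gradient layout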
d_param_mu.copy_(d_param_mu_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_mu.dtype))
d_param_phi.copy_(d_param_phi_partial.sum(dim=(0, -2), keepdim=True).unsqueeze(-2).to(d_param_phi.dtype))
class EvaPrepKVFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, k, v, param_mu, param_phi, mask, softmax_scale=None, chunksize=None):
if mask is not None:
mask_type = 1
else:
mask_type = 0
rfa_k, rfa_v = triton_eva_prep_kv_fwd(
k, v, param_mu, param_phi, mask, softmax_scale, chunksize
)
ctx.save_for_backward(k, v, param_mu, param_phi, mask, rfa_k, rfa_v)
ctx.softmax_scale = softmax_scale
ctx.chunksize = chunksize
ctx.mask_type = mask_type
return rfa_k, rfa_v
@staticmethod
def backward(ctx, d_rfa_k, d_rfa_v):
k, v, param_mu, param_phi, mask, rfa_k, rfa_v = ctx.saved_tensors
d_k = torch.empty_like(k)
d_v = torch.empty_like(v)
d_param_mu = torch.empty_like(param_mu)
d_param_phi = torch.empty_like(param_phi)
triton_eva_prep_kv_bwd(
d_rfa_k, d_rfa_v,
k, v, param_mu, param_phi,
mask,
rfa_k, rfa_v,
d_k, d_v, d_param_mu, d_param_phi,
ctx.softmax_scale,
ctx.mask_type,
ctx.chunksize
)
return d_k, d_v, d_param_mu, d_param_phi, None, None, None
def eva_prep_kv_func_triton(
k, v,
param_mu, param_phi,
mask,
softmax_scale=None, chunksize=None
):
return EvaPrepKVFunc.apply(
k, v,
param_mu, param_phi,
mask,
softmax_scale, chunksize
)
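
# Minimal usage sketch (illustration only, not part of the kernel API): random
# inputs are pushed through the autograd wrapper and the forward outputs are
# compared against the pure-PyTorch reference above. The shapes, dtype, and
# chunksize below are assumptions chosen for this example.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.manual_seed(0)
        batch, nheads, seqlen, head_dim, chunksize = 2, 4, 512, 64, 16
        k = torch.randn(batch, nheads, seqlen, head_dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        v = torch.randn_like(k, requires_grad=True)
        param_mu = torch.randn(1, nheads, 1, 1, head_dim, device="cuda", dtype=torch.bfloat16, requires_grad=True)
        param_phi = torch.randn_like(param_mu, requires_grad=True)
        mask = None  # optionally a [batch, 1, seqlen, 1] bool mask, True = masked out
        rfa_k, rfa_v = eva_prep_kv_func_triton(k, v, param_mu, param_phi, mask, None, chunksize)
        with torch.no_grad():
            ref_k, ref_v = _reference_eva_prep_kv(
                k, v, param_mu, param_phi, mask, 1.0 / math.sqrt(head_dim), chunksize
            )
        print("max |rfa_k - ref_k|:", (rfa_k - ref_k).abs().max().item())
        print("max |rfa_v - ref_v|:", (rfa_v - ref_v).abs().max().item())
        # backward pass through the Triton kernels populates k.grad, v.grad, param_mu.grad, param_phi.grad
        (rfa_k.float().sum() + rfa_v.float().sum()).backward()
        print("||d_k||:", k.grad.float().norm().item())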