|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
|
|
from fla.ops.utils.op import exp, gather |
|
from fla.utils import is_gather_supported, use_cuda_graph |
|
|
|
|
|
# Off-diagonal (inter sub-chunk) tiles of the intra-chunk score matrices.
#
# Grid axes: (chunk index, flattened (row sub-chunk, column sub-chunk) pair,
# batch * head). A program returns immediately unless its row sub-chunk lies
# strictly after its column sub-chunk, so only tiles below the block diagonal
# are computed here. For such a tile, the kernel accumulates over K (in steps
# of BK) the four products q.k, q.b, a.b and a.k, with gate terms taken
# relative to the `gi` row just before the row sub-chunk, and writes one
# [BC, BC] tile into each of Aqk, Aqb, Aab and Aak.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [2, 4, 8, 16]
        for num_stages in [2, 3, 4]
    ],
    key=['BC', 'K'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_A_kernel_intra_sub_inter(
    q,
    k,
    a,
    b,
    gi,
    ge,
    Aqk,
    Aqb,
    Aab,
    Aak,
    offsets,
    indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # axis 1 of the grid enumerates (row sub-chunk, column sub-chunk) pairs
    i_i, i_j = i_c // NC, i_c % NC
    if USE_OFFSETS:
        # `indices` holds (sequence id, chunk id) pairs; `offsets` holds the
        # sequence boundaries, from which the per-sequence length is derived.
        i_n = tl.load(indices + i_t * 2).to(tl.int32)
        i_t = tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos = tl.load(offsets + i_n).to(tl.int32)
        eos = tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos = i_b * T
        eos = bos + T

    # first time step of the row / column sub-chunk handled by this program
    o_i = i_t * BT + i_i * BC
    o_j = i_t * BT + i_j * BC

    # row sub-chunk starts past the end of the sequence: nothing to do
    if o_i >= T:
        return
    # only tiles strictly below the block diagonal are handled here
    if i_i <= i_j:
        return

    # element offsets of this (batch, head) slice in the inputs / outputs
    if HEAD_FIRST:
        o_in = i_bh * T * K
        o_out = i_bh * T * BT
    else:
        o_in = (bos * H + i_h) * K
        o_out = (bos * H + i_h) * BT

    acc_qk = tl.zeros([BC, BC], dtype=tl.float32)
    acc_qb = tl.zeros([BC, BC], dtype=tl.float32)
    acc_ab = tl.zeros([BC, BC], dtype=tl.float32)
    acc_ak = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K

        if HEAD_FIRST:
            p_q = tl.make_block_ptr(q + o_in, (T, K), (K, 1), (o_i, i_k * BK), (BC, BK), (1, 0))
            p_a = tl.make_block_ptr(a + o_in, (T, K), (K, 1), (o_i, i_k * BK), (BC, BK), (1, 0))
            p_gi = tl.make_block_ptr(gi + o_in, (T, K), (K, 1), (o_i, i_k * BK), (BC, BK), (1, 0))
            p_ge = tl.make_block_ptr(ge + o_in, (T, K), (K, 1), (o_i, i_k * BK), (BC, BK), (1, 0))
            # column-side operands are read transposed: [BK, BC]
            p_k = tl.make_block_ptr(k + o_in, (K, T), (1, K), (i_k * BK, o_j), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b + o_in, (K, T), (1, K), (i_k * BK, o_j), (BK, BC), (0, 1))
            p_gk = tl.make_block_ptr(gi + o_in, (K, T), (1, K), (i_k * BK, o_j), (BK, BC), (0, 1))
            # `gi` row immediately preceding the row sub-chunk
            p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + o_i - 1) * K + o_k, BK), BK)
        else:
            p_q = tl.make_block_ptr(q + o_in, (T, K), (H*K, 1), (o_i, i_k * BK), (BC, BK), (1, 0))
            p_a = tl.make_block_ptr(a + o_in, (T, K), (H*K, 1), (o_i, i_k * BK), (BC, BK), (1, 0))
            p_gi = tl.make_block_ptr(gi + o_in, (T, K), (H*K, 1), (o_i, i_k * BK), (BC, BK), (1, 0))
            p_ge = tl.make_block_ptr(ge + o_in, (T, K), (H*K, 1), (o_i, i_k * BK), (BC, BK), (1, 0))
            # column-side operands are read transposed: [BK, BC]
            p_k = tl.make_block_ptr(k + o_in, (K, T), (1, H*K), (i_k * BK, o_j), (BK, BC), (0, 1))
            p_b = tl.make_block_ptr(b + o_in, (K, T), (1, H*K), (i_k * BK, o_j), (BK, BC), (0, 1))
            p_gk = tl.make_block_ptr(gi + o_in, (K, T), (1, H*K), (i_k * BK, o_j), (BK, BC), (0, 1))
            # `gi` row immediately preceding the row sub-chunk
            p_gn = gi + (bos + o_i - 1) * H*K + i_h * K + o_k

        # [BK] reference gate shared by both sides of the product
        b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)

        # row side: [BC, BK] tiles, gates taken relative to b_gn
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_a = tl.load(p_a, boundary_check=(0, 1))
        b_gi = tl.load(p_gi, boundary_check=(0, 1))
        b_ge = tl.load(p_ge, boundary_check=(0, 1))
        b_ag = b_a * exp(b_ge - b_gn[None, :])
        b_qg = b_q * exp(b_gi - b_gn[None, :]) * scale

        # column side: [BK, BC] tiles, gates taken relative to b_gn
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_b = tl.load(p_b, boundary_check=(0, 1))
        b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
        b_decay = exp(b_gn[:, None] - b_gk)
        b_kg = b_k * b_decay
        b_bg = b_b * b_decay

        # [BC, BK] @ [BK, BC] -> [BC, BC]
        acc_ab += tl.dot(b_ag, b_bg)
        acc_ak += tl.dot(b_ag, b_kg)
        acc_qk += tl.dot(b_qg, b_kg)
        acc_qb += tl.dot(b_qg, b_bg)

    if HEAD_FIRST:
        p_Aqk = tl.make_block_ptr(Aqk + o_out, (T, BT), (BT, 1), (o_i, i_j * BC), (BC, BC), (1, 0))
        p_Aqb = tl.make_block_ptr(Aqb + o_out, (T, BT), (BT, 1), (o_i, i_j * BC), (BC, BC), (1, 0))
        p_Aab = tl.make_block_ptr(Aab + o_out, (T, BT), (BT, 1), (o_i, i_j * BC), (BC, BC), (1, 0))
        p_Aak = tl.make_block_ptr(Aak + o_out, (T, BT), (BT, 1), (o_i, i_j * BC), (BC, BC), (1, 0))
    else:
        p_Aqk = tl.make_block_ptr(Aqk + o_out, (T, BT), (H*BT, 1), (o_i, i_j * BC), (BC, BC), (1, 0))
        p_Aqb = tl.make_block_ptr(Aqb + o_out, (T, BT), (H*BT, 1), (o_i, i_j * BC), (BC, BC), (1, 0))
        p_Aab = tl.make_block_ptr(Aab + o_out, (T, BT), (H*BT, 1), (o_i, i_j * BC), (BC, BC), (1, 0))
        p_Aak = tl.make_block_ptr(Aak + o_out, (T, BT), (H*BT, 1), (o_i, i_j * BC), (BC, BC), (1, 0))
    # each accumulator is cast to the element type of its own output buffer
    tl.store(p_Aqk, acc_qk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aqb, acc_qb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aab, acc_ab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_Aak, acc_ak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
|
|
|
|
|
# Diagonal (within sub-chunk) tiles of the intra-chunk score matrices, plus
# the gated copies of q, k, a and b.
#
# Grid axes: (chunk index, sub-chunk index, batch * head). Each program
#   1. loads one [BC, BK] tile of q, k, a, b, gi and ge (BK spans all of K),
#   2. writes qg = q * scale * exp(gi), ag = a * exp(ge),
#      kg = k * exp(g_last - gi) and bg = b * exp(g_last - gi), where g_last
#      is the `gi` row at the last valid time step of the chunk,
#   3. fills, column by column, the [BC, BC] diagonal tile of Aqk / Aqb
#      (rows >= column kept) and of Aak / Aab (rows > column kept).
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8, 16, 32]
        for num_stages in [2, 3, 4]
    ],
    key=['BK', 'BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_A_kernel_intra_sub_intra(
    q,
    k,
    a,
    b,
    gi,
    ge,
    qg,
    kg,
    ag,
    bg,
    Aqk,
    Aqb,
    Aab,
    Aak,
    offsets,
    indices,
    scale: tl.constexpr,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    GATHER_SUPPORTED: tl.constexpr
):
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # diagonal tile: the column sub-chunk is the row sub-chunk
    i_j = i_i
    if USE_OFFSETS:
        # `indices` holds (sequence id, chunk id) pairs; `offsets` holds the
        # sequence boundaries, from which the per-sequence length is derived.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # first time step of the sub-chunk handled by this program
    o_t = i_t * BT + i_i * BC
    if o_t >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    # rows of the sub-chunk that fall inside the sequence
    m_A = (o_t + o_i) < T
    # last valid time step of the enclosing chunk
    last_idx = min((i_t+1) * BT, T) - 1
    if HEAD_FIRST:
        o_bh = i_bh * T*K
        # per-row offset of column 0 of this program's tile in the A outputs
        o_A = i_bh * T*BT + (o_t + o_i) * BT + i_j * BC
        p_q = tl.make_block_ptr(q + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_a = tl.make_block_ptr(a + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_b = tl.make_block_ptr(b + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_gi = tl.make_block_ptr(gi + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_ge = tl.make_block_ptr(ge + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_g_last = gi + o_bh + last_idx * K + o_k

        p_qg = tl.make_block_ptr(qg + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_kg = tl.make_block_ptr(kg + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_ag = tl.make_block_ptr(ag + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_bg = tl.make_block_ptr(bg + o_bh, (T, K), (K, 1), (o_t, 0), (BC, BK), (1, 0))
    else:
        o_bh = (bos * H + i_h) * K
        # per-row offset of column 0 of this program's tile in the A outputs
        o_A = (bos + o_t + o_i) * H*BT + i_h * BT + i_j * BC
        p_q = tl.make_block_ptr(q + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_k = tl.make_block_ptr(k + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_a = tl.make_block_ptr(a + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_b = tl.make_block_ptr(b + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_gi = tl.make_block_ptr(gi + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_ge = tl.make_block_ptr(ge + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_g_last = gi + o_bh + last_idx * H * K + o_k

        p_qg = tl.make_block_ptr(qg + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_kg = tl.make_block_ptr(kg + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_ag = tl.make_block_ptr(ag + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))
        p_bg = tl.make_block_ptr(bg + o_bh, (T, K), (H*K, 1), (o_t, 0), (BC, BK), (1, 0))

    # [BK] `gi` row at the last valid step of the chunk
    b_g_last = tl.load(p_g_last, mask=m_k, other=0)

    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_q = b_q * scale
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_a = tl.load(p_a, boundary_check=(0, 1))
    b_b = tl.load(p_b, boundary_check=(0, 1))
    b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
    b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)

    # gated copies of the inputs, written out alongside the A tiles
    g_exp = exp(b_gi)
    g_exp_inv = exp(-b_gi + b_g_last[None, :])
    b_qg = b_q * g_exp
    b_kg = b_k * g_exp_inv
    b_bg = b_b * g_exp_inv
    b_ag = b_a * exp(b_ge)
    tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
    tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))

    # multiplying by `scale` may have changed the dtype of q; bring it back to k's
    b_q = b_q.to(b_k.dtype)

    # one column of the diagonal tile per valid row j of the sub-chunk
    for j in range(0, min(BC, T - o_t)):
        # extract row j of k, gi and b as [1, BK] tiles
        if GATHER_SUPPORTED:
            row_idx = tl.full([1, BK], j, dtype=tl.int16)
            b_k_j = gather(b_k, row_idx, axis=0)
            b_gk_j = gather(b_gi, row_idx, axis=0)
            b_b_j = gather(b_b, row_idx, axis=0)
        else:
            # fallback: select row j with a one-hot mask and reduce over rows
            mask = tl.arange(0, BC) == j
            b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
            b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
            b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
        # q-side products use gi and keep the diagonal (rows >= j)
        tmp = exp(b_gi - b_gk_j)
        b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
        b_A_qk = tl.where(o_i >= j, b_A_qk, 0.)
        b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
        b_A_qb = tl.where(o_i >= j, b_A_qb, 0.)
        # a-side products use ge and exclude the diagonal (rows > j)
        tmp2 = exp(b_ge - b_gk_j)
        b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
        b_A_ak = tl.where(o_i > j, b_A_ak, 0.)
        b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
        b_A_ab = tl.where(o_i > j, b_A_ab, 0.)
        # Cast each column to the element type of the buffer it is stored in,
        # so a wider Aab / Aak buffer is not rounded through the narrower
        # Aqb / Aqk element type first.
        tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aab.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aak.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
|
|
|
|
|
def chunk_fwd_intra_dplr_fn(
    q: torch.Tensor,
    k: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    gi: torch.Tensor,
    ge: torch.Tensor,
    scale: float,
    chunk_size: int,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
):
    """Launch the two intra-chunk DPLR forward kernels and collect their outputs.

    The first launch fills the tiles of the four ``A`` buffers that lie below
    the block diagonal of each chunk; the second fills the diagonal tiles and
    also writes the gated tensors ``qg``, ``kg``, ``ag`` and ``bg``.

    Args:
        q, k, a, b, gi, ge: Kernel inputs. ``k`` is unpacked as
            ``(B, H, T, K)`` when ``head_first`` is true and as
            ``(B, T, H, K)`` otherwise; the other inputs are presumably laid
            out the same way (not checked here).
        scale: Scaling factor forwarded to both kernels.
        chunk_size: Upper bound on the chunk length ``BT``.
        offsets: Optional sequence boundaries forwarded to the kernels. When
            given, the number of chunks is ``len(indices)``.
        indices: Optional chunk index table forwarded to the kernels.
        head_first: Selects the memory layout described above.

    Returns:
        The tuple ``(Aab, Aqk, Aak, Aqb, qg, kg, ag, bg)``. ``Aqk`` and
        ``Aqb`` use ``q.dtype``; ``Aab`` and ``Aak`` are ``torch.float``.
        Each ``A`` buffer has the leading dims of ``k`` minus ``K`` plus a
        trailing dim of ``BT``, and is allocated uninitialized.
    """
    if head_first:
        B, H, T, K = k.shape
        lead_shape = (B, H, T)
    else:
        B, T, H, K = k.shape
        lead_shape = (B, T, H)

    # chunk length, number of chunks, sub-chunk length, sub-chunks per chunk
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    if offsets is None:
        NT = triton.cdiv(T, BT)
    else:
        NT = len(indices)
    BC = min(16, BT)
    NC = triton.cdiv(BT, BC)

    Aqk = q.new_empty(*lead_shape, BT, dtype=q.dtype)
    Aqb = q.new_empty(*lead_shape, BT, dtype=q.dtype)
    # the a-side buffers are kept in float32
    Aab = q.new_empty(*lead_shape, BT, dtype=torch.float)
    Aak = q.new_empty(*lead_shape, BT, dtype=torch.float)

    # one program per (chunk, (row sub-chunk, column sub-chunk) pair, batch * head)
    inter_grid = (NT, NC * NC, B * H)
    chunk_dplr_fwd_A_kernel_intra_sub_inter[inter_grid](
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        Aqk=Aqk,
        Aqb=Aqb,
        Aab=Aab,
        Aak=Aak,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
        HEAD_FIRST=head_first,
    )

    # one program per (chunk, sub-chunk, batch * head); BK covers all of K
    intra_grid = (NT, NC, B * H)
    BK = triton.next_power_of_2(K)
    qg = torch.empty_like(q)
    kg = torch.empty_like(k, dtype=q.dtype)
    ag = torch.empty_like(a, dtype=q.dtype)
    bg = torch.empty_like(b, dtype=q.dtype)
    chunk_dplr_fwd_A_kernel_intra_sub_intra[intra_grid](
        q=q,
        k=k,
        a=a,
        b=b,
        gi=gi,
        ge=ge,
        Aqk=Aqk,
        Aqb=Aqb,
        Aab=Aab,
        Aak=Aak,
        qg=qg,
        kg=kg,
        ag=ag,
        bg=bg,
        offsets=offsets,
        indices=indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
        HEAD_FIRST=head_first,
        NC=NC,
        GATHER_SUPPORTED=is_gather_supported,
    )
    return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
|
|