|
|
|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
|
|
from fla.utils import check_shared_mem, use_cuda_graph |
|
|
|
# Candidate tile sizes that the autotuner below sweeps for both BK and BV.
# The larger set is chosen only when check_shared_mem() returns a truthy value.
if check_shared_mem():
    BK_LIST = [32, 64, 128]
else:
    BK_LIST = [16, 32]
|
|
|
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BK': bk, 'BV': bv}, num_warps=nw, num_stages=ns)
        for bk in BK_LIST
        for bv in BK_LIST
        for nw in [2, 4, 8, 16, 32]
        for ns in [2, 3, 4]
    ],
    key=['BT'],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_kernel_o(
    qg,
    v,
    v_new,
    A_qk,
    A_qb,
    h,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    """Compute one [BT, BV] tile of ``o = qg @ h + tril(A_qk) @ v + tril(A_qb) @ v_new``.

    Program ids: axis 0 picks the BV-wide column tile along V, axis 1 picks the
    time chunk, and axis 2 is a fused index that is split into a batch index
    (``// H``) and a head index (``% H``).

    When ``USE_OFFSETS`` is set (i.e. ``offsets`` is not None), ``indices`` is
    read as consecutive pairs ``(sequence index, chunk index inside that
    sequence)``, ``offsets`` supplies the begin/end position of each sequence,
    and ``T`` is replaced by the length of the selected sequence.

    ``HEAD_FIRST`` only selects how base pointers and row strides are formed:
    rows are ``K``/``V``/``BT`` elements apart when it is set, and
    ``H*K``/``H*V``/``H*BT`` elements apart otherwise.
    """
    pid_v = tl.program_id(0)
    pid_t = tl.program_id(1)
    pid_bh = tl.program_id(2)
    idx_b = pid_bh // H
    idx_h = pid_bh % H

    if USE_OFFSETS:
        # The launch-wide chunk id doubles as the index of the hidden state.
        chunk_flat = pid_t
        seq_id = tl.load(indices + pid_t * 2).to(tl.int32)
        chunk_id = tl.load(indices + pid_t * 2 + 1).to(tl.int32)
        seq_begin = tl.load(offsets + seq_id).to(tl.int32)
        seq_end = tl.load(offsets + seq_id + 1).to(tl.int32)
        T = seq_end - seq_begin
        n_chunks = tl.cdiv(T, BT)
    else:
        chunk_id = pid_t
        n_chunks = tl.cdiv(T, BT)
        chunk_flat = idx_b * n_chunks + chunk_id
        seq_begin = idx_b * T

    # Origin of this program's tile: rows along time, columns along V.
    row0 = chunk_id * BT
    col0 = pid_v * BV

    # Base addresses do not depend on the K-loop index, so form them once.
    if HEAD_FIRST:
        bh_row = pid_bh * T
        qg_base = qg + bh_row * K
        h_base = h + (pid_bh * n_chunks + chunk_id) * K*V
    else:
        head_row = seq_begin * H + idx_h
        qg_base = qg + head_row * K
        h_base = h + (chunk_flat * H + idx_h) * K*V

    # Contribution of the chunk state: accumulate qg @ h over K in BK-wide slices.
    acc = tl.zeros([BT, BV], dtype=tl.float32)
    for blk_k in range(tl.cdiv(K, BK)):
        if HEAD_FIRST:
            ptr_qg = tl.make_block_ptr(qg_base, (T, K), (K, 1), (row0, blk_k * BK), (BT, BK), (1, 0))
        else:
            ptr_qg = tl.make_block_ptr(qg_base, (T, K), (H*K, 1), (row0, blk_k * BK), (BT, BK), (1, 0))
        ptr_h = tl.make_block_ptr(h_base, (K, V), (V, 1), (blk_k * BK, col0), (BK, BV), (1, 0))
        blk_qg = tl.load(ptr_qg, boundary_check=(0, 1))
        blk_h = tl.load(ptr_h, boundary_check=(0, 1))
        acc += tl.dot(blk_qg, blk_h)

    if HEAD_FIRST:
        ptr_Aqk = tl.make_block_ptr(A_qk + bh_row * BT, (T, BT), (BT, 1), (row0, 0), (BT, BT), (1, 0))
        ptr_Aqb = tl.make_block_ptr(A_qb + bh_row * BT, (T, BT), (BT, 1), (row0, 0), (BT, BT), (1, 0))
        ptr_v = tl.make_block_ptr(v + bh_row * V, (T, V), (V, 1), (row0, col0), (BT, BV), (1, 0))
        ptr_v_new = tl.make_block_ptr(v_new + bh_row * V, (T, V), (V, 1), (row0, col0), (BT, BV), (1, 0))
        ptr_o = tl.make_block_ptr(o + bh_row * V, (T, V), (V, 1), (row0, col0), (BT, BV), (1, 0))
    else:
        ptr_Aqk = tl.make_block_ptr(A_qk + head_row * BT, (T, BT), (H*BT, 1), (row0, 0), (BT, BT), (1, 0))
        ptr_Aqb = tl.make_block_ptr(A_qb + head_row * BT, (T, BT), (H*BT, 1), (row0, 0), (BT, BT), (1, 0))
        ptr_v = tl.make_block_ptr(v + head_row * V, (T, V), (H*V, 1), (row0, col0), (BT, BV), (1, 0))
        ptr_v_new = tl.make_block_ptr(v_new + head_row * V, (T, V), (H*V, 1), (row0, col0), (BT, BV), (1, 0))
        ptr_o = tl.make_block_ptr(o + head_row * V, (T, V), (H*V, 1), (row0, col0), (BT, BV), (1, 0))

    # Keep only entries on or below the diagonal of the two [BT, BT] blocks.
    keep = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
    blk_Aqk = tl.where(keep, tl.load(ptr_Aqk, boundary_check=(0, 1)), 0)
    blk_Aqb = tl.where(keep, tl.load(ptr_Aqb, boundary_check=(0, 1)), 0)
    blk_v = tl.load(ptr_v, boundary_check=(0, 1))
    blk_v_new = tl.load(ptr_v_new, boundary_check=(0, 1))

    # Each A block is cast to the dtype of the value block it multiplies.
    out_qk = tl.dot(blk_Aqk.to(blk_v.dtype), blk_v)
    out_qb = tl.dot(blk_Aqb.to(blk_v_new.dtype), blk_v_new)
    acc = acc + out_qk + out_qb
    tl.store(ptr_o, acc.to(ptr_o.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
def chunk_dplr_fwd_o(
    qg: torch.Tensor,
    v: torch.Tensor,
    v_new: torch.Tensor,
    A_qk: torch.Tensor,
    A_qb: torch.Tensor,
    h: torch.Tensor,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None,
    head_first: bool = True,
    chunk_size: int = 64
) -> torch.Tensor:
    """Launch ``chunk_dplr_fwd_kernel_o`` and return the output it fills.

    Args:
        qg: 4-D tensor unpacked as ``(B, H, T, K)`` when ``head_first`` is
            true and as ``(B, T, H, K)`` otherwise.
        v: Tensor whose last dimension is taken as ``V``; the output is
            allocated with ``torch.empty_like(v)``.
        v_new: Forwarded to the kernel unchanged.
        A_qk: Forwarded to the kernel unchanged.
        A_qb: Forwarded to the kernel unchanged.
        h: Forwarded to the kernel unchanged.
        offsets: Optional sequence boundaries. When given, the number of
            chunk programs launched is ``len(indices)`` instead of
            ``cdiv(T, BT)``.
        indices: Forwarded to the kernel; must support ``len`` whenever
            ``offsets`` is given.
        head_first: Selects the shape unpacking above and the kernel's
            ``HEAD_FIRST`` flag.
        chunk_size: Upper bound on the chunk length ``BT``.

    Returns:
        The tensor ``o``, with the same shape and dtype as ``v``.
    """
    dims = (*qg.shape, v.shape[-1])
    if head_first:
        B, H, T, K, V = dims
    else:
        B, T, H, K, V = dims

    # Chunk length: next power of two of T, at least 16, capped by chunk_size.
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    if offsets is None:
        NT = triton.cdiv(T, BT)
    else:
        NT = len(indices)

    o = torch.empty_like(v)

    def grid(meta):
        # BV is picked by the autotuner, so the grid is resolved at launch.
        return (triton.cdiv(V, meta['BV']), NT, B * H)

    chunk_dplr_fwd_kernel_o[grid](
        qg=qg,
        v=v,
        v_new=v_new,
        A_qk=A_qk,
        A_qb=A_qb,
        h=h,
        o=o,
        offsets=offsets,
        indices=indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        HEAD_FIRST=head_first
    )
    return o
|
|