# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from typing import Optional
import torch
import triton
import triton.language as tl
from fla.ops.utils.op import exp, gather
from fla.utils import is_gather_supported, use_cuda_graph
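

# Computes the strictly lower (inter sub-chunk) [BC, BC] blocks of the four
# intra-chunk score matrices Aqk, Aqb, Aab and Aak. For each pair of
# sub-chunks with i_i > i_j it accumulates, over key tiles of width BK,
# decay-weighted dot products between the query-side tensors (q, a) of
# sub-chunk i_i and the key-side tensors (k, b) of sub-chunk i_j, with both
# sides normalized by the decay state b_gn at the boundary row.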
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
configs=[
triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
for BK in [32, 64]
for num_warps in [2, 4, 8, 16]
for num_stages in [2, 3, 4]
],
key=['BC', 'K'],
use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_A_kernel_intra_sub_inter(
q,
k,
a,
b,
    gi,  # inclusive cumsum of log decay
    ge,  # exclusive cumsum of log decay (i.e. before the current step)
Aqk,
Aqb,
Aab,
Aak,
offsets,
indices,
scale: tl.constexpr,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
USE_OFFSETS: tl.constexpr,
HEAD_FIRST: tl.constexpr,
):
i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
i_i, i_j = i_c // NC, i_c % NC
if USE_OFFSETS:
i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
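    # exit if this sub-chunk starts past the end of the sequence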
if i_t * BT + i_i * BC >= T:
return
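    # only strictly lower sub-blocks (i_i > i_j) are handled here; the diagonal
    # blocks are computed by chunk_dplr_fwd_A_kernel_intra_sub_intra below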
if i_i <= i_j:
return
b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
b_Aqb = tl.zeros([BC, BC], dtype=tl.float32)
b_Aab = tl.zeros([BC, BC], dtype=tl.float32)
b_Aak = tl.zeros([BC, BC], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K
if HEAD_FIRST:
p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gq_i = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gq_e = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_b = tl.make_block_ptr(b + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gk = tl.make_block_ptr(gi + i_bh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gn = tl.max_contiguous(tl.multiple_of(gi + (i_bh * T + i_t * BT + i_i * BC - 1) * K + o_k, BK), BK)
else:
p_q = tl.make_block_ptr(q + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_a = tl.make_block_ptr(a + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gq_i = tl.make_block_ptr(gi + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_gq_e = tl.make_block_ptr(ge + (bos*H+i_h)*K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_b = tl.make_block_ptr(b + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gk = tl.make_block_ptr(gi + (bos*H+i_h)*K, (K, T), (1, H*K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))
p_gn = gi + (bos + i_t * BT + i_i * BC - 1) * H*K + i_h * K + o_k
        # [BK,] decay cumsum at the last row before sub-chunk i_i begins
b_gn = tl.load(p_gn, mask=m_k, other=0).to(tl.float32)
# [BC, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_a = tl.load(p_a, boundary_check=(0, 1))
b_gq_i = tl.load(p_gq_i, boundary_check=(0, 1))
b_gq_e = tl.load(p_gq_e, boundary_check=(0, 1))
b_ag = b_a * exp(b_gq_e - b_gn[None, :])
b_qg = b_q * exp(b_gq_i - b_gn[None, :]) * scale
# [BK, BC]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_b = tl.load(p_b, boundary_check=(0, 1))
b_gk = tl.load(p_gk, boundary_check=(0, 1)).to(tl.float32)
tmp = exp(b_gn[:, None] - b_gk)
b_kg = b_k * tmp
b_bg = b_b * tmp
        # [BC, BC] accumulators; tl.dot uses tf32 inputs with fp32 accumulation
b_Aab += tl.dot(b_ag, b_bg)
b_Aak += tl.dot(b_ag, b_kg)
b_Aqk += tl.dot(b_qg, b_kg)
b_Aqb += tl.dot(b_qg, b_bg)
if HEAD_FIRST:
p_Aqk = tl.make_block_ptr(Aqk + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
p_Aqb = tl.make_block_ptr(Aqb + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
p_Aab = tl.make_block_ptr(Aab + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
p_Aak = tl.make_block_ptr(Aak + i_bh*T*BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
else:
p_Aqk = tl.make_block_ptr(Aqk + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
p_Aqb = tl.make_block_ptr(Aqb + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
p_Aab = tl.make_block_ptr(Aab + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
p_Aak = tl.make_block_ptr(Aak + (bos*H+i_h)*BT, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))
tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Aqb, b_Aqb.to(Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Aab, b_Aab.to(Aab.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_Aak, b_Aak.to(Aak.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
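

# Computes the diagonal (intra sub-chunk) [BC, BC] blocks. Each program
# handles one sub-chunk, iterating over its rows j with causal masking
# (inclusive o_i >= j for Aqk/Aqb, strict o_i > j for Aab/Aak), and also
# writes out the decay-scaled tensors qg, kg, ag and bg consumed by the
# inter-chunk pass.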
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4, 8, 16, 32]
for num_stages in [2, 3, 4]
],
key=['BK', 'BT'],
use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=['T'])
def chunk_dplr_fwd_A_kernel_intra_sub_intra(
q,
k,
a,
b,
    gi,  # inclusive cumsum of log decay
    ge,  # exclusive cumsum of log decay
qg,
kg,
ag,
bg,
Aqk,
Aqb,
Aab,
Aak,
offsets,
indices,
scale: tl.constexpr,
T,
H: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BC: tl.constexpr,
BK: tl.constexpr,
NC: tl.constexpr,
USE_OFFSETS: tl.constexpr,
HEAD_FIRST: tl.constexpr,
GATHER_SUPPORTED: tl.constexpr
):
i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
i_j = i_i
if USE_OFFSETS:
i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
if i_t * BT + i_i * BC >= T:
return
o_i = tl.arange(0, BC)
o_k = tl.arange(0, BK)
m_k = o_k < K
m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T
last_idx = min((i_t+1) * BT, T) - 1
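    # last_idx is the final valid row of this chunk; its decay state b_g_last
    # is used below to rescale k/b toward the chunk boundary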
if HEAD_FIRST:
o_A = i_bh * T*BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC
p_q = tl.make_block_ptr(q + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_a = tl.make_block_ptr(a + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_b = tl.make_block_ptr(b + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_gi = tl.make_block_ptr(gi + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_ge = tl.make_block_ptr(ge + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_g_last = gi + i_bh * T*K + last_idx * K + tl.arange(0, BK)
b_g_last = tl.load(p_g_last, mask=m_k, other=0)
p_qg = tl.make_block_ptr(qg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_kg = tl.make_block_ptr(kg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_ag = tl.make_block_ptr(ag + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_bg = tl.make_block_ptr(bg + i_bh * T*K, (T, K), (K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
else:
o_A = (bos + i_t * BT + i_i * BC + tl.arange(0, BC)) * H*BT + i_h * BT + i_j * BC
p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_a = tl.make_block_ptr(a + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_b = tl.make_block_ptr(b + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_gi = tl.make_block_ptr(gi + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_ge = tl.make_block_ptr(ge + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_g_last = gi + (bos * H + i_h) * K + last_idx * H * K + tl.arange(0, BK)
b_g_last = tl.load(p_g_last, mask=m_k, other=0)
p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_ag = tl.make_block_ptr(ag + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
p_bg = tl.make_block_ptr(bg + (bos * H + i_h) * K, (T, K), (H*K, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = b_q * scale
b_k = tl.load(p_k, boundary_check=(0, 1))
b_a = tl.load(p_a, boundary_check=(0, 1))
b_b = tl.load(p_b, boundary_check=(0, 1))
b_gi = tl.load(p_gi, boundary_check=(0, 1)).to(tl.float32)
b_ge = tl.load(p_ge, boundary_check=(0, 1)).to(tl.float32)
    # apply the decay terms: q is scaled by exp(gi), a by exp(ge), while k and b
    # are rescaled by exp(b_g_last - gi) toward the end of the chunk
g_exp = exp(b_gi)
g_exp_inv = exp(-b_gi + b_g_last[None, :])
b_qg = b_q * g_exp
b_kg = b_k * g_exp_inv
b_bg = b_b * g_exp_inv
b_ag = b_a * exp(b_ge)
tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_bg, b_bg.to(p_bg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_ag, b_ag.to(p_ag.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1))
b_q = b_q.to(b_k.dtype)
    # intra sub-chunk attention: iterate over rows j of this [BC, BC] block
for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        # a trick to extract the j-th row of b_k, b_gi, b_b
if GATHER_SUPPORTED:
row_idx = tl.full([1, BK], j, dtype=tl.int16)
# [1, BK]
b_k_j = gather(b_k, row_idx, axis=0)
b_gk_j = gather(b_gi, row_idx, axis=0)
b_b_j = gather(b_b, row_idx, axis=0)
else:
mask = tl.arange(0, BC) == j
b_k_j = tl.sum(tl.where(mask[:, None], b_k, 0), 0)[None, :]
b_gk_j = tl.sum(tl.where(mask[:, None], b_gi, 0), 0)[None, :]
b_b_j = tl.sum(tl.where(mask[:, None], b_b, 0), 0)[None, :]
tmp = exp(b_gi - b_gk_j)
b_A_qk = tl.sum(b_q * b_k_j * tmp, 1)
b_A_qk = tl.where(o_i >= j, b_A_qk, 0.)
b_A_qb = tl.sum(b_q * b_b_j * tmp, 1)
b_A_qb = tl.where(o_i >= j, b_A_qb, 0.)
tmp2 = exp(b_ge - b_gk_j)
b_A_ak = tl.sum(b_a * b_k_j * tmp2, 1)
b_A_ak = tl.where(o_i > j, b_A_ak, 0.)
b_A_ab = tl.sum(b_a * b_b_j * tmp2, 1)
b_A_ab = tl.where(o_i > j, b_A_ab, 0.)
tl.store(Aqk + o_A + j, b_A_qk.to(dtype=Aqk.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
tl.store(Aqb + o_A + j, b_A_qb.to(dtype=Aqb.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aab + o_A + j, b_A_ab.to(dtype=Aab.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
        tl.store(Aak + o_A + j, b_A_ak.to(dtype=Aak.dtype.element_ty, fp_downcast_rounding="rtne"), mask=m_A)
def chunk_fwd_intra_dplr_fn(
q: torch.Tensor,
k: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
gi: torch.Tensor,
ge: torch.Tensor,
scale: float,
chunk_size: int,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None,
head_first: bool = True,
):
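    """
    Compute the intra-chunk score matrices for the chunked DPLR forward pass.

    Two kernels are launched: the first fills the strictly lower (inter
    sub-chunk) blocks of Aqk/Aqb/Aab/Aak, the second fills the diagonal
    (intra sub-chunk) blocks and also materializes the decay-scaled tensors
    qg/kg/ag/bg consumed by the inter-chunk pass.

    Returns:
        Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
    """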
if head_first:
B, H, T, K = k.shape
else:
B, T, H, K = k.shape
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
NT = triton.cdiv(T, BT) if offsets is None else len(indices)
BC = min(16, BT)
NC = triton.cdiv(BT, BC)
Aqk = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
Aqb = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=q.dtype)
    # Aab and Aak feed a matrix inverse downstream, so keep them in float32
Aab = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
Aak = q.new_empty(B, *((H, T) if head_first else (T, H)), BT, dtype=torch.float)
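    # grid: one program per (chunk, pair of sub-chunks, batch * head);
    # upper and diagonal pairs exit early inside the kernel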
grid = (NT, NC * NC, B * H)
chunk_dplr_fwd_A_kernel_intra_sub_inter[grid](
q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
offsets=offsets, indices=indices,
scale=scale,
T=T, H=H, K=K, BT=BT, BC=BC, NC=NC,
HEAD_FIRST=head_first
)
grid = (NT, NC, B * H)
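    # the diagonal kernel covers the whole key dimension with a single block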
BK = triton.next_power_of_2(K)
qg = torch.empty_like(q)
kg = torch.empty_like(k, dtype=q.dtype)
ag = torch.empty_like(a, dtype=q.dtype)
bg = torch.empty_like(b, dtype=q.dtype)
chunk_dplr_fwd_A_kernel_intra_sub_intra[grid](
q=q, k=k, a=a, b=b, gi=gi, ge=ge, Aqk=Aqk, Aqb=Aqb, Aab=Aab, Aak=Aak,
qg=qg, kg=kg, ag=ag, bg=bg,
offsets=offsets, indices=indices,
scale=scale,
T=T, H=H, K=K, BT=BT, BC=BC, BK=BK, HEAD_FIRST=head_first, NC=NC,
GATHER_SUPPORTED=is_gather_supported
)
return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
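

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, not part of the library API. It assumes a CUDA
# device and fabricates illustrative inputs; gi/ge are filled with small
# negative cumulative sums as stand-ins for the real per-chunk log-decay
# cumsums, which are normally produced by a separate cumsum kernel.
if __name__ == '__main__':
    B, T, H, K = 2, 128, 4, 64
    q = torch.randn(B, T, H, K, device='cuda', dtype=torch.bfloat16)
    k, a, b = torch.randn_like(q), torch.randn_like(q), torch.randn_like(q)
    # small negative log decays keep the exp() calls in the kernels well-behaved
    g = -0.1 * torch.rand(B, T, H, K, device='cuda', dtype=torch.float)
    gi = g.cumsum(1)   # stand-in for the inclusive cumsum
    ge = gi - g        # stand-in for the exclusive cumsum
    Aab, Aqk, Aak, Aqb, qg, kg, ag, bg = chunk_fwd_intra_dplr_fn(
        q=q, k=k, a=a, b=b, gi=gi, ge=ge,
        scale=K ** -0.5, chunk_size=64, head_first=False,
    )
    print(Aqk.shape, Aqk.dtype)  # (2, 128, 4, 64), torch.bfloat16 (q.dtype)
    print(Aab.shape, Aab.dtype)  # (2, 128, 4, 64), torch.float32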