zaydzuhri's picture
Add files using upload-large-folder tool
f72219a verified
raw
history blame
12.5 kB
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from typing import Optional
import torch
import triton
import triton.language as tl
from fla.utils import check_shared_mem, input_guard
# Candidate tile widths (`BS`) that `chunk_local_cumsum_vector_kernel` autotunes over.
# NOTE(review): `check_shared_mem` is imported from `fla.utils`; presumably it reports
# whether the device has enough shared memory for the larger tiles -- confirm there.
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
# `USE_OFFSETS` is derived at launch time: it is true whenever `offsets` is passed.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
# Only the warp count is tuned; the tuning result is keyed on the chunk size `BT`.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    # Inclusive cumulative sum inside one chunk of `BT` consecutive time steps.
    #
    # s / o:    input / output pointers; one scalar per (time step, head).
    # offsets:  when given, entries `i_n` and `i_n + 1` are read as the start / end
    #           position of sequence `i_n`.
    # indices:  when `offsets` is given, read as pairs
    #           (sequence index, chunk index inside that sequence).
    # T:        sequence length; replaced by `eos - bos` in offsets mode.
    # H:        number of heads.
    # REVERSE:  produce suffix sums instead of prefix sums.
    #
    # Launch grid: axis 0 selects the chunk, axis 1 the flattened (batch, head) pair.
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Remap the flat chunk id to (sequence, chunk-in-sequence) and fetch that
        # sequence's boundaries; `i_b` is not used on this path.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        # Time is the fastest-varying axis for each head (unit stride).
        p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        # Heads are interleaved: consecutive time steps of one head are `H` apart.
        p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    # Accumulate in fp32 regardless of the input element type.
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # suffix_sum[i] = total - prefix_sum[i] + x[i]
        # NOTE(review): `b_z` sums the whole block, so positions past `T` in the last
        # chunk must load as zero for this to be exact -- confirm the padding
        # behavior of `boundary_check` loads without `padding_option`.
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    # Cast to the output element type on store.
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
# `USE_OFFSETS` is derived at launch time: it is true whenever `offsets` is passed.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
# Tunes the feature-tile width `BS` and the warp count, keyed on `S` and `BT`.
@triton.autotune(
    configs=[
        triton.Config({'BS': BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['S', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    # Inclusive cumulative sum along time inside one chunk of `BT` steps, applied
    # independently to each of the `S` features of a (time step, head) vector.
    #
    # s / o:    input / output pointers.
    # offsets:  when given, entries `i_n` and `i_n + 1` are read as the start / end
    #           position of sequence `i_n`.
    # indices:  when `offsets` is given, read as pairs
    #           (sequence index, chunk index inside that sequence).
    # T:        sequence length; replaced by `eos - bos` in offsets mode.
    # H, S:     number of heads and feature size.
    # BT, BS:   tile sizes along time and features.
    # REVERSE:  produce suffix sums instead of prefix sums.
    #
    # Launch grid: axis 0 selects the feature tile, axis 1 the chunk, axis 2 the
    # flattened (batch, head) pair.
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # Remap the flat chunk id to (sequence, chunk-in-sequence) and fetch that
        # sequence's boundaries; `i_b` is not used on this path.
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, BT)
    # [BT, BT] 0/1 mask; multiplying by it sums rows j <= i (forward) or j >= i (reverse).
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
    if HEAD_FIRST:
        # Rows of one head are contiguous: stride `S` between time steps.
        p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        # Heads are interleaved: stride `H*S` between time steps of one head.
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    # [BT, BS]
    # Accumulate in fp32 regardless of the input element type.
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    # The cumulative sum is expressed as a triangular-mask matmul; TF32 is disabled
    # for the product.
    # NOTE(review): in reverse mode rows past `T` in the last chunk take part in the
    # sum, so they must load as zero -- confirm the padding behavior of
    # `boundary_check` loads without `padding_option`.
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    # Cast to the output element type on store.
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
# `USE_OFFSETS` is derived at launch time: it is true whenever `offsets` is passed.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
# Tunes the time-tile size `BT` together with the warp count. The key list is empty,
# so no kernel argument takes part in selecting the cached configuration.
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=4),
    ],
    key=[]
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    # Inclusive cumulative sum over the whole sequence of one (batch, head) pair,
    # processed tile by tile with a running carry.
    #
    # s / o:    input / output pointers; one scalar per (time step, head).
    # offsets:  when given, entries `i_b` and `i_b + 1` are read as the start / end
    #           position of sequence `i_b`.
    # T:        sequence length; recomputed as `eos - bos` below.
    # H:        number of heads.
    # REVERSE:  produce suffix sums (tiles are visited from last to first).
    #
    # Launch grid: a single axis over flattened (batch-or-sequence, head) pairs.
    i_bh = tl.program_id(0)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos
    # Running total of all tiles processed so far (fp32 scalar).
    b_z = tl.zeros([], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        # Visit tiles back-to-front when accumulating in reverse.
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            # Time is the fastest-varying axis for each head (unit stride).
            p_s = tl.make_block_ptr(s + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h*T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        else:
            # Heads are interleaved: consecutive time steps of one head are `H` apart.
            p_s = tl.make_block_ptr(s + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos*H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
        b_o = tl.cumsum(b_s, axis=0)
        # NOTE(review): the tile sum includes positions past `T` in the last tile, so
        # those must load as zero -- confirm the padding behavior of
        # `boundary_check` loads without `padding_option`.
        b_ss = tl.sum(b_s, 0)
        if REVERSE:
            # suffix_sum[i] = tile_total - prefix_sum[i] + x[i]
            b_o = -b_o + b_ss + b_s
        # Add the carry from previously processed tiles.
        b_o += b_z
        # This condition is always true for `i_c` produced by `range(NT)`.
        # NOTE(review): possibly a compiler workaround -- confirm before simplifying.
        if i_c >= 0:
            b_z += b_ss
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
# `USE_OFFSETS` is derived at launch time: it is true whenever `offsets` is passed.
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
# Tunes the time-tile size `BT` and the warp count, keyed on the feature size `S`.
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [2, 4, 8]
    ],
    key=['S']
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_vector_kernel(
    s,
    z,
    offsets,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
    # Inclusive cumulative sum along time over the whole sequence of one
    # (batch, head) pair, for a tile of `BS` features, processed `BT` time steps
    # at a time with a running per-feature carry.
    #
    # s / z:    input / output pointers.
    # offsets:  when given, entries `i_b` and `i_b + 1` are read as the start / end
    #           position of sequence `i_b`.
    # T:        sequence length; recomputed as `eos - bos` below.
    # H, S:     number of heads and feature size.
    # REVERSE:  produce suffix sums (tiles are visited from last to first).
    #
    # Launch grid: axis 0 selects the feature tile, axis 1 the flattened
    # (batch-or-sequence, head) pair.
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos
    o_i = tl.arange(0, BT)
    # [BT, BT] 0/1 mask; multiplying by it sums rows j <= i (forward) or j >= i (reverse).
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
    # Per-feature running total of all tiles processed so far (fp32).
    b_z = tl.zeros([BS], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        # Visit tiles back-to-front when accumulating in reverse.
        i_t = NT-1-i_c if REVERSE else i_c
        if HEAD_FIRST:
            # Rows of one head are contiguous: stride `S` between time steps.
            p_s = tl.make_block_ptr(s + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h*T)*S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        else:
            # Heads are interleaved: stride `H*S` between time steps of one head.
            p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H*S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        # In-tile cumulative sum via triangular-mask matmul (TF32 disabled), plus carry.
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        # This condition is always true for `i_c` produced by `range(NT)`.
        # NOTE(review): possibly a compiler workaround -- confirm before simplifying.
        # NOTE(review): the tile sum includes rows past `T` in the last tile, so those
        # must load as zero -- confirm the padding behavior of `boundary_check` loads.
        if i_c >= 0:
            b_z += tl.sum(b_s, 0)
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Launch the chunk-local cumulative sum kernel for a 3-D input.

    Args:
        g: Input of shape `[B, H, T]` if `head_first` else `[B, T, H]`.
        chunk_size: Number of time steps per chunk; must be a power of 2.
        reverse: Compute suffix sums inside each chunk instead of prefix sums.
        offsets: Optional sequence boundaries; the kernel reads entries `n` and
            `n + 1` as the start / end of sequence `n`.
        indices: Required when `offsets` is given; one entry per chunk, read by
            the kernel as a `(sequence index, chunk index)` pair.
        head_first: Selects the memory layout described for `g`.
        output_dtype: Element type of the result; `None` keeps `g.dtype`.

    Returns:
        A new tensor with the same shape as `g` holding the per-chunk sums.

    Raises:
        AssertionError: If `chunk_size` is not a power of 2.
    """
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    # `B` is intentionally NOT replaced by `len(offsets) - 1` here: in offsets mode
    # the kernel takes the sequence id from `indices` and ignores the batch index,
    # so enlarging the grid by the number of sequences would only recompute (and
    # rewrite) every chunk redundantly.
    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
    BT = chunk_size
    # With variable-length sequences the number of chunks comes from `indices`.
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        g_org,
        g,
        offsets,
        indices,
        T=T,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return g
def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Launch the chunk-local cumulative sum kernel for a 4-D input.

    Args:
        g: Input of shape `[B, H, T, S]` if `head_first` else `[B, T, H, S]`.
        chunk_size: Number of time steps per chunk; must be a power of 2.
        reverse: Compute suffix sums inside each chunk instead of prefix sums.
        offsets: Optional sequence boundaries; the kernel reads entries `n` and
            `n + 1` as the start / end of sequence `n`.
        indices: Required when `offsets` is given; one entry per chunk, read by
            the kernel as a `(sequence index, chunk index)` pair.
        head_first: Selects the memory layout described for `g`.
        output_dtype: Element type of the result; `None` keeps `g.dtype`.

    Returns:
        A new tensor with the same shape as `g` holding the per-chunk sums.

    Raises:
        AssertionError: If `chunk_size` is not a power of 2.
    """
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    # Validate before `chunk_size` is used as a divisor, so that an invalid value
    # (e.g. 0) reports the intended message instead of a ZeroDivisionError.
    assert chunk_size == 2**(chunk_size.bit_length()-1), "chunk_size must be a power of 2"
    BT = chunk_size
    # With variable-length sequences the number of chunks comes from `indices`.
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        # `BS` is chosen by the autotuner, so the grid is resolved lazily.
        return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)

    # keep cumulative normalizer in fp32
    # for the head-first, forward case with T divisible by BT this kernel is
    # equivalent to
    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
    chunk_local_cumsum_vector_kernel[grid](
        g_org,
        g,
        offsets,
        indices,
        T=T,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return g
@input_guard
def chunk_global_cumsum_scalar(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Launch the whole-sequence cumulative sum kernel for a 3-D input.

    Args:
        s: Input of shape `[B, H, T]` if `head_first` else `[B, T, H]`.
        dtype: Fallback result type, used only when `output_dtype` is `None`;
            defaults to `s.dtype`.
        reverse: Compute suffix sums instead of prefix sums.
        offsets: Optional sequence boundaries; when given, one kernel program
            per (sequence, head) is launched, with `len(offsets) - 1` sequences.
        head_first: Selects the memory layout described for `s`.
        output_dtype: Element type of the result.

    Returns:
        A new tensor with the same shape as `s` holding the cumulative sums.
    """
    # Unpacking three names keeps the implicit "must be 3-D" check.
    n_batch, dim_1, dim_2 = s.shape
    n_heads, seq_len = (dim_1, dim_2) if head_first else (dim_2, dim_1)
    if offsets is not None:
        # One program per variable-length sequence rather than per batch row.
        n_batch = len(offsets) - 1

    result_dtype = output_dtype or dtype or s.dtype
    z = torch.empty_like(s, dtype=result_dtype)

    launch_grid = (n_batch * n_heads,)
    chunk_global_cumsum_scalar_kernel[launch_grid](
        s,
        z,
        offsets,
        T=seq_len,
        H=n_heads,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return z
@input_guard
def chunk_global_cumsum_vector(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Launch the whole-sequence cumulative sum kernel for a 4-D input.

    Args:
        s: Input of shape `[B, H, T, S]` if `head_first` else `[B, T, H, S]`.
        dtype: Fallback result type, used only when `output_dtype` is `None`;
            defaults to `s.dtype`.
        reverse: Compute suffix sums instead of prefix sums.
        offsets: Optional sequence boundaries; when given, programs are launched
            per (sequence, head), with `len(offsets) - 1` sequences.
        head_first: Selects the memory layout described for `s`.
        output_dtype: Element type of the result.

    Returns:
        A new tensor with the same shape as `s` holding the cumulative sums.
    """
    # Unpacking four names keeps the implicit "must be 4-D" check.
    n_batch, dim_1, dim_2, feat = s.shape
    n_heads, seq_len = (dim_1, dim_2) if head_first else (dim_2, dim_1)
    if offsets is not None:
        # One group of programs per variable-length sequence.
        n_batch = len(offsets) - 1

    # Feature tile width: the feature size rounded up to a power of 2, capped at 32.
    feat_block = min(32, triton.next_power_of_2(feat))

    result_dtype = output_dtype or dtype or s.dtype
    z = torch.empty_like(s, dtype=result_dtype)

    launch_grid = (triton.cdiv(feat, feat_block), n_batch * n_heads)
    chunk_global_cumsum_vector_kernel[launch_grid](
        s,
        z,
        offsets,
        T=seq_len,
        H=n_heads,
        S=feat,
        BS=feat_block,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return z
@input_guard
def chunk_global_cumsum(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Cumulative sum over the full time axis, dispatching on the input rank.

    3-D inputs go to `chunk_global_cumsum_scalar`, 4-D inputs to
    `chunk_global_cumsum_vector`; all arguments are forwarded unchanged.

    Raises:
        AssertionError: If `offsets` is given and the leading dimension is not 1.
        ValueError: If `s` is neither 3-D nor 4-D.
    """
    if offsets is not None:
        assert s.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"

    rank = len(s.shape)
    if rank == 3:
        return chunk_global_cumsum_scalar(s, dtype, reverse, offsets, head_first, output_dtype)
    if rank == 4:
        return chunk_global_cumsum_vector(s, dtype, reverse, offsets, head_first, output_dtype)
    raise ValueError(f"Unsupported input shape {s.shape}. "
                     f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
                     f"or [B, T, H]/[B, T, H, D] otherwise")
@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
    """Cumulative sum inside fixed-size chunks of the time axis, dispatching on rank.

    3-D inputs go to `chunk_local_cumsum_scalar`, 4-D inputs to
    `chunk_local_cumsum_vector`; all arguments are forwarded unchanged.

    Args:
        g: `[B, H, T]` / `[B, H, T, D]` if `head_first`, else
            `[B, T, H]` / `[B, T, H, D]`.
        chunk_size: Number of time steps per chunk.
        reverse: Compute suffix sums inside each chunk instead of prefix sums.
        offsets: Optional sequence boundaries for variable-length inputs.
        indices: Per-chunk lookup entries used together with `offsets`.
        head_first: Selects the memory layout described for `g`.
        output_dtype: Element type of the result.

    Raises:
        AssertionError: If `offsets` is given and the leading dimension is not 1.
        ValueError: If `g` is neither 3-D nor 4-D.
    """
    if offsets is not None:
        assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    else:
        # The message lists every accepted layout, matching `chunk_global_cumsum`.
        raise ValueError(f"Unsupported input shape {g.shape}. "
                         f"which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
                         f"or [B, T, H]/[B, T, H, D] otherwise")