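"""
Triton kernels and host-side wrappers for chunk-wise (local) and
sequence-wide (global) cumulative sums over the time dimension,
supporting reversed order and variable-length sequences via `offsets`.
"""
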
from typing import Optional

import torch
import triton
import triton.language as tl

from fla.utils import check_shared_mem, input_guard
|
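# Candidate block sizes along the feature dimension for autotuning; the larger
# blocks are only attempted when enough shared memory is available.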
BS_LIST = [32, 64] if check_shared_mem() else [16, 32] |
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
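    """
    Compute the cumulative sum of `s` within each chunk of `BT` time steps and
    write the result to `o`. With `REVERSE=True` the cumsum runs from the end
    of each chunk instead (an inclusive suffix sum).
    """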
|
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
    else:
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))

    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        # reverse inclusive cumsum: total sum minus forward cumsum plus the element itself
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BS': BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=['S', 'BT'],
)
@triton.jit(do_not_specialize=['T'])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
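    """
    Vector counterpart of the scalar kernel: per-chunk cumulative sum over the
    time dimension for `[T, S]`-shaped inputs, realized as a matmul with a
    triangular mask of ones.
    """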
|
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # triangular mask of ones: row i sums inputs j <= i (forward) or j >= i (reverse)
    o_i = tl.arange(0, BT)
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
    else:
        p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))

    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BT': 16}, num_warps=2),
        triton.Config({'BT': 32}, num_warps=4),
        triton.Config({'BT': 32}, num_warps=2),
        triton.Config({'BT': 64}, num_warps=8),
        triton.Config({'BT': 64}, num_warps=4),
    ],
    key=[]
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_scalar_kernel(
    s,
    o,
    offsets,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
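    """
    Cumulative sum of `s` over the full sequence, processed chunk by chunk:
    each chunk's local cumsum is offset by the running total `b_z` of all
    previously visited chunks (visited back to front when `REVERSE=True`).
    """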
|
    i_bh = tl.program_id(0)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    # running sum carried across chunks
    b_z = tl.zeros([], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        i_t = NT - 1 - i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
        else:
            p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
        b_o = tl.cumsum(b_s, axis=0)
        b_ss = tl.sum(b_s, 0)
        if REVERSE:
            b_o = -b_o + b_ss + b_s
        b_o += b_z
        b_z += b_ss
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None,
})
@triton.autotune(
    configs=[
        triton.Config({'BT': BT}, num_warps=num_warps)
        for BT in [16, 32, 64]
        for num_warps in [2, 4, 8]
    ],
    key=['S']
)
@triton.jit(do_not_specialize=['T'])
def chunk_global_cumsum_vector_kernel(
    s,
    z,
    offsets,
    T,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
    USE_OFFSETS: tl.constexpr,
    REVERSE: tl.constexpr
):
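    """
    Vector counterpart of the scalar global kernel: sequence-wide cumulative
    sum for `[T, S]`-shaped inputs, combining a per-chunk triangular matmul
    with a `[BS]`-sized running total carried across chunks.
    """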
|
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        bos, eos = tl.load(offsets + i_b).to(tl.int32), tl.load(offsets + i_b + 1).to(tl.int32)
    else:
        bos, eos = i_b * T, i_b * T + T
    T = eos - bos

    # triangular mask of ones: row i sums inputs j <= i (forward) or j >= i (reverse)
    o_i = tl.arange(0, BT)
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)

    # running sum carried across chunks
    b_z = tl.zeros([BS], dtype=tl.float32)
    NT = tl.cdiv(T, BT)
    for i_c in range(NT):
        i_t = NT - 1 - i_c if REVERSE else i_c
        if HEAD_FIRST:
            p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        else:
            p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
            p_z = tl.make_block_ptr(z + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0))

        b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
        b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)
        tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        b_z += tl.sum(b_s, 0)
|
|
def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
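    """
    Launch `chunk_local_cumsum_scalar_kernel` on a `[B, H, T]` (or `[B, T, H]`
    if `head_first=False`) tensor and return the per-chunk cumulative sums.
    """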
|
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    if offsets is not None:
        B = len(offsets) - 1
    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        g_org,
        g,
        offsets,
        indices,
        T=T,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return g
|
|
def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
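    """
    Launch `chunk_local_cumsum_vector_kernel` on a `[B, H, T, S]` (or
    `[B, T, H, S]` if `head_first=False`) tensor; the feature dimension is
    tiled into autotuned blocks of size `BS`.
    """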
|
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)
    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"

    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)

    chunk_local_cumsum_vector_kernel[grid](
        g_org,
        g,
        offsets,
        indices,
        T=T,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return g
|
|
@input_guard
def chunk_global_cumsum_scalar(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
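    """
    Launch `chunk_global_cumsum_scalar_kernel`: sequence-wide (optionally
    reversed) cumulative sum of a `[B, H, T]` or `[B, T, H]` tensor.
    """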
|
    dtype = dtype or s.dtype
    if head_first:
        B, H, T = s.shape
    else:
        B, T, H = s.shape
    if offsets is not None:
        B = len(offsets) - 1
    grid = (B * H,)
    z = torch.empty_like(s, dtype=output_dtype or dtype)
    chunk_global_cumsum_scalar_kernel[grid](
        s,
        z,
        offsets,
        T=T,
        H=H,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return z
|
|
@input_guard
def chunk_global_cumsum_vector(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
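    """
    Launch `chunk_global_cumsum_vector_kernel`: sequence-wide (optionally
    reversed) cumulative sum of a `[B, H, T, S]` or `[B, T, H, S]` tensor.
    """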
|
    dtype = dtype or s.dtype
    if head_first:
        B, H, T, S = s.shape
    else:
        B, T, H, S = s.shape
    BS = min(32, triton.next_power_of_2(S))
    if offsets is not None:
        B = len(offsets) - 1
    grid = (triton.cdiv(S, BS), B * H)
    z = torch.empty_like(s, dtype=output_dtype or dtype)
    chunk_global_cumsum_vector_kernel[grid](
        s,
        z,
        offsets,
        T=T,
        H=H,
        S=S,
        BS=BS,
        HEAD_FIRST=head_first,
        REVERSE=reverse
    )
    return z
|
|
@input_guard
def chunk_global_cumsum(
    s: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
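    """
    Dispatch to the scalar or vector global-cumsum wrapper based on whether
    the input is 3-D or 4-D.
    """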
|
    if offsets is not None:
        assert s.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if len(s.shape) == 3:
        return chunk_global_cumsum_scalar(s, dtype, reverse, offsets, head_first, output_dtype)
    elif len(s.shape) == 4:
        return chunk_global_cumsum_vector(s, dtype, reverse, offsets, head_first, output_dtype)
    else:
        raise ValueError(f"Unsupported input shape {s.shape}, "
                         "which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
                         "or [B, T, H]/[B, T, H, D] otherwise")
|
|
@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    offsets: Optional[torch.Tensor] = None,
    indices: Optional[torch.Tensor] = None,
    head_first: bool = True,
    output_dtype: Optional[torch.dtype] = torch.float
) -> torch.Tensor:
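    """
    Dispatch to the scalar or vector local-cumsum wrapper based on whether
    the input is 3-D or 4-D.
    """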
|
    if offsets is not None:
        assert g.shape[0] == 1, "Only batch size 1 is supported when offsets are provided"
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(g, chunk_size, reverse, offsets, indices, head_first, output_dtype)
    else:
        raise ValueError(f"Unsupported input shape {g.shape}, "
                         "which should be [B, H, T]/[B, H, T, D] if `head_first=True` "
                         "or [B, T, H]/[B, T, H, D] otherwise")
|
|
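if __name__ == '__main__':
    # Minimal sanity-check sketch, not part of the library API. It assumes a
    # CUDA device is available and compares the global kernels against
    # torch.cumsum and the local kernels against a per-chunk reference.
    B, H, T, S = 2, 4, 128, 64
    x = torch.randn(B, H, T, S, device='cuda')

    # sequence-wide cumsum over the time dimension (dim=2 for head-first layout)
    ref = torch.cumsum(x.float(), dim=2)
    out = chunk_global_cumsum(x, head_first=True)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-3)

    # per-chunk cumsum with chunk_size=64: cumsum restarts at every chunk boundary
    ref_local = torch.cat([torch.cumsum(c.float(), dim=2) for c in x.split(64, dim=2)], dim=2)
    out_local = chunk_local_cumsum(x, chunk_size=64, head_first=True)
    torch.testing.assert_close(out_local, ref_local, rtol=1e-4, atol=1e-3)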