# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.common.utils import prepare_chunk_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard


@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
configs=[
triton.Config({'BD': BD}, num_warps=num_warps)
for BD in [16, 32, 64, 128]
for num_warps in [1, 2, 4, 8]
],
key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_fwd_kernel(
x,
o,
offsets,
indices,
T: tl.constexpr,
H: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr,
NT: tl.constexpr,
USE_OFFSETS: tl.constexpr
):
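    # grid: (ceil(D / BD) feature tiles, NT chunks, B * H batch-head pairs)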
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
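    # variable-length mode: look up this program's (sequence, chunk) pair in
    # `indices` and recover the sequence boundaries from `offsets`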
if USE_OFFSETS:
i_tg = i_t
i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
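    # block pointers: a [BT, BD] tile of the input chunk and its [BD] slot in the output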
p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
# [BT, BD]
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
# [BD]
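    # divide by the actual chunk length so a short final chunk is still a proper mean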
b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
configs=[
triton.Config({'BD': BD}, num_warps=num_warps)
for BD in [16, 32, 64, 128]
for num_warps in [1, 2, 4, 8]
],
key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_bwd_kernel(
do,
dx,
offsets,
indices,
T: tl.constexpr,
H: tl.constexpr,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr,
NT: tl.constexpr,
USE_OFFSETS: tl.constexpr
):
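    # grid mirrors the forward pass: (feature tiles, chunks, batch-head pairs)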
i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
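    # variable-length mode: same (sequence, chunk) lookup as in the forward kernel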
if USE_OFFSETS:
i_tg = i_t
i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
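    # block pointers: the [BD] incoming gradient and the [BT, BD] tile it scatters into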
p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))
# [BD]
b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
# [BT, BD]
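    # d(mean)/dx = 1/chunk_length for each timestep, so broadcast the incoming
    # gradient over the chunk scaled by the reciprocal of the true chunk length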
b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))


def mean_pooling_fwd(
x: torch.Tensor,
chunk_size: int,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None
) -> torch.Tensor:
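    """
    Launch the forward kernel, averaging `x` over non-overlapping chunks of
    length `chunk_size` along the time dimension.

    `x` has shape `[B, T, H, D]`; the output has shape `[B, NT, H, D]`, where
    `NT` is the total number of chunks (`ceil(T / chunk_size)` for fixed-length
    inputs, or `len(indices)` for variable-length ones).
    """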
B, T, H, D = x.shape
BT = chunk_size
NT = triton.cdiv(T, BT) if offsets is None else len(indices)
o = x.new_empty(B, NT, H, D)
def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
mean_pooling_fwd_kernel[grid](
x,
o,
offsets,
indices,
T=T,
H=H,
D=D,
BT=BT,
NT=NT,
)
    return o


def mean_pooling_bwd(
do: torch.Tensor,
batch_size: int,
seq_len: int,
chunk_size: int,
offsets: Optional[torch.LongTensor] = None,
indices: Optional[torch.LongTensor] = None
) -> torch.Tensor:
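    """
    Launch the backward kernel, scattering each chunk's incoming gradient `do`
    of shape `[B, NT, H, D]` back to an input gradient of shape `[B, T, H, D]`,
    scaled by the reciprocal of the chunk length.
    """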
B, T, H, D = batch_size, seq_len, *do.shape[-2:]
BT = chunk_size
NT = triton.cdiv(T, BT) if offsets is None else len(indices)
dx = do.new_empty(B, T, H, D)
def grid(meta): return (triton.cdiv(D, meta['BD']), NT, B * H)
mean_pooling_bwd_kernel[grid](
do,
dx,
offsets,
indices,
T=T,
H=H,
D=D,
BT=BT,
NT=NT
)
    return dx


class MeanPoolingFunction(torch.autograd.Function):
@staticmethod
@input_guard
@autocast_custom_fwd
def forward(
ctx,
x: torch.Tensor,
chunk_size: int,
offsets: Optional[torch.LongTensor] = None
) -> torch.Tensor:
        # 2-D indices denoting the (sequence, chunk) pair of every chunk across all sequences.
        # For example, if the passed `offsets` is [0, 100, 356] and `chunk_size` is 64,
        # then the 1st and 2nd sequences contain 2 and 4 chunks respectively, and `indices` will be
        # [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
o = mean_pooling_fwd(x, chunk_size, offsets, indices)
ctx.batch_size = x.shape[0]
ctx.seq_len = x.shape[1]
ctx.chunk_size = chunk_size
ctx.offsets = offsets
ctx.indices = indices
        return o

@staticmethod
@input_guard
@autocast_custom_bwd
def backward(
ctx, do
) -> Tuple[torch.Tensor, None, None]:
batch_size = ctx.batch_size
seq_len = ctx.seq_len
chunk_size = ctx.chunk_size
offsets = ctx.offsets
indices = ctx.indices
dx = mean_pooling_bwd(do, batch_size, seq_len, chunk_size, offsets, indices)
        return dx, None, None


def mean_pooling(
x: torch.Tensor,
chunk_size: int,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False
) -> torch.Tensor:
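    """
    Mean-pool the input over non-overlapping chunks along the time dimension.

    Args:
        x: inputs of shape `[B, T, H, D]` (or `[B, H, T, D]` if `head_first=True`).
        chunk_size: the length of each pooled chunk.
        cu_seqlens: optional cumulative sequence lengths for variable-length
            inputs, which must be flattened into a single batch of size 1.
        head_first: whether the head dimension precedes the time dimension.

    Returns:
        The per-chunk means, laid out with the same convention as the input.
    """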
if head_first:
x = x.transpose(1, 2)
if cu_seqlens is not None:
if x.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
                             f"Please flatten variable-length inputs before processing.")
o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
if head_first:
o = o.transpose(1, 2)
return o
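

if __name__ == '__main__':
    # Minimal sanity-check sketch, not part of the library API: it assumes a
    # CUDA device is available (Triton kernels require a GPU) and compares the
    # fused kernels against a pure PyTorch reference for forward and backward.
    B, T, H, D, chunk_size = 2, 200, 4, 64, 64
    x = torch.randn(B, T, H, D, device='cuda', requires_grad=True)
    x_ref = x.detach().clone().requires_grad_(True)

    o = mean_pooling(x, chunk_size)
    # reference: split the time dimension into chunks and average each one
    o_ref = torch.stack([c.mean(1) for c in x_ref.split(chunk_size, dim=1)], dim=1)
    torch.testing.assert_close(o, o_ref, atol=1e-4, rtol=1e-4)

    # every timestep of a chunk should receive grad / chunk_length
    o.sum().backward()
    o_ref.sum().backward()
    torch.testing.assert_close(x.grad, x_ref.grad, atol=1e-4, rtol=1e-4)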