from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.ops.common.utils import prepare_chunk_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard


@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_fwd_kernel(
    x,
    o,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    NT: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # each program averages one (chunk, feature-block) tile for one head
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        # variable-length mode: `indices` maps the program id to a
        # (sequence, chunk) pair and `offsets` holds the sequence boundaries
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_x = tl.make_block_ptr(x + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_o = tl.make_block_ptr(o + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))

    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    # divide by the actual chunk length so the trailing partial chunk is
    # averaged correctly rather than by a full BT
    b_o = tl.sum(b_x, axis=0) / min(BT, T - i_t * BT)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
|
@triton.heuristics({
    'USE_OFFSETS': lambda args: args['offsets'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({'BD': BD}, num_warps=num_warps)
        for BD in [16, 32, 64, 128]
        for num_warps in [1, 2, 4, 8]
    ],
    key=['BT']
)
@triton.jit(do_not_specialize=['T'])
def mean_pooling_bwd_kernel(
    do,
    dx,
    offsets,
    indices,
    T,
    H: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    NT: tl.constexpr,
    USE_OFFSETS: tl.constexpr
):
    # mirrors the forward kernel: one program per (chunk, feature-block, head)
    i_d, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if USE_OFFSETS:
        i_tg = i_t
        i_n, i_t = tl.load(indices + i_t * 2).to(tl.int32), tl.load(indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    p_dx = tl.make_block_ptr(dx + (bos * H + i_h) * D, (T, D), (H*D, 1), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
    p_do = tl.make_block_ptr(do + (i_tg * H + i_h) * D, (D,), (1,), (i_d * BD,), (BD,), (0,))

    b_do = tl.load(p_do, boundary_check=(0,)).to(tl.float32)
    # the gradient of a mean is uniform: broadcast `do` over the chunk and
    # scale by 1 / (actual chunk length)
    b_dx = b_do / tl.full((BT,), min(BT, T - i_t * BT), dtype=tl.float32)[:, None]
    tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
|
|
def mean_pooling_fwd(
    x: torch.Tensor,
    chunk_size: int,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    B, T, H, D = x.shape
    BT = chunk_size
    # with `offsets`, every row of `indices` is one (sequence, chunk) pair,
    # so the total number of output chunks is simply `len(indices)`
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    o = x.new_empty(B, NT, H, D)

    def grid(meta):
        return (triton.cdiv(D, meta['BD']), NT, B * H)
    mean_pooling_fwd_kernel[grid](
        x,
        o,
        offsets,
        indices,
        T=T,
        H=H,
        D=D,
        BT=BT,
        NT=NT
    )
    return o
|
|
def mean_pooling_bwd(
    do: torch.Tensor,
    batch_size: int,
    seq_len: int,
    chunk_size: int,
    offsets: Optional[torch.LongTensor] = None,
    indices: Optional[torch.LongTensor] = None
) -> torch.Tensor:
    B, T, H, D = batch_size, seq_len, *do.shape[-2:]
    BT = chunk_size
    NT = triton.cdiv(T, BT) if offsets is None else len(indices)

    dx = do.new_empty(B, T, H, D)

    def grid(meta):
        return (triton.cdiv(D, meta['BD']), NT, B * H)
    mean_pooling_bwd_kernel[grid](
        do,
        dx,
        offsets,
        indices,
        T=T,
        H=H,
        D=D,
        BT=BT,
        NT=NT
    )
    return dx
|
|
class MeanPoolingFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        x: torch.Tensor,
        chunk_size: int,
        offsets: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        indices = prepare_chunk_indices(offsets, chunk_size) if offsets is not None else None
        o = mean_pooling_fwd(x, chunk_size, offsets, indices)
        # stash the shapes and chunk metadata needed to rebuild `dx`
        ctx.batch_size = x.shape[0]
        ctx.seq_len = x.shape[1]
        ctx.chunk_size = chunk_size
        ctx.offsets = offsets
        ctx.indices = indices
        return o

    @staticmethod
    @input_guard
    @autocast_custom_bwd
    def backward(ctx, do) -> Tuple[torch.Tensor, None, None]:
        dx = mean_pooling_bwd(do, ctx.batch_size, ctx.seq_len, ctx.chunk_size, ctx.offsets, ctx.indices)
        return dx, None, None
|
|
def mean_pooling(
    x: torch.Tensor,
    chunk_size: int,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False
) -> torch.Tensor:
    """
    Mean-pool `x` over non-overlapping chunks of length `chunk_size` along the time dimension.

    Args:
        x: inputs of shape `[B, T, H, D]`, or `[B, H, T, D]` if `head_first=True`.
        chunk_size: the pooling window size; a trailing partial chunk is averaged over its actual length.
        cu_seqlens: cumulative sequence lengths for variable-length inputs, in which case `B` must be 1
            and the output time dimension is the total number of chunks over all sequences.
        head_first: whether inputs (and outputs) are in the head-first layout.

    Returns:
        outputs of shape `[B, ceil(T / chunk_size), H, D]`
        (`[B, H, ceil(T / chunk_size), D]` if `head_first=True`).
    """
    if head_first:
        x = x.transpose(1, 2)
    if cu_seqlens is not None:
        if x.shape[0] != 1:
            raise ValueError(f"The batch size is expected to be 1 rather than {x.shape[0]} when using `cu_seqlens`. "
                             f"Please flatten variable-length inputs before processing.")
    o = MeanPoolingFunction.apply(x, chunk_size, cu_seqlens)
    if head_first:
        o = o.transpose(1, 2)
    return o
|
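# A minimal pure-PyTorch reference, assuming the fixed-length `[B, T, H, D]`
# layout (no `cu_seqlens`). `mean_pooling_ref` is a hypothetical helper for
# sanity-checking the Triton kernels above, not part of the public API.
def mean_pooling_ref(x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # `split` leaves a shorter trailing chunk, and `mean` divides by its
    # actual length, matching the partial-chunk handling in the kernel
    chunks = x.split(chunk_size, dim=1)
    return torch.stack([c.float().mean(dim=1).to(x.dtype) for c in chunks], dim=1)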
|
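# A small smoke-test sketch, assuming a CUDA device is available; T = 100 with
# chunk_size = 16 deliberately exercises a partial trailing chunk, and the
# varlen check assumes `prepare_chunk_indices` enumerates ceil(len / chunk_size)
# chunks per sequence, as implied by `NT = len(indices)` above.
if __name__ == '__main__':
    x = torch.randn(2, 100, 4, 64, device='cuda', requires_grad=True)
    o = mean_pooling(x, chunk_size=16)
    assert o.shape == (2, triton.cdiv(100, 16), 4, 64)
    torch.testing.assert_close(o, mean_pooling_ref(x, 16), rtol=1e-4, atol=1e-4)
    # the backward pass spreads each pooled gradient uniformly over its chunk
    o.sum().backward()
    assert x.grad.shape == x.shape

    # variable-length inputs: two sequences of lengths 40 and 60, flattened
    # into a single batch row and described by `cu_seqlens`
    x_flat = torch.randn(1, 100, 4, 64, device='cuda')
    cu_seqlens = torch.tensor([0, 40, 100], dtype=torch.long, device='cuda')
    o_var = mean_pooling(x_flat, chunk_size=16, cu_seqlens=cu_seqlens)
    # ceil(40 / 16) + ceil(60 / 16) = 3 + 4 = 7 chunks in total
    assert o_var.shape == (1, 7, 4, 64)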