from typing import Optional

import torch
import triton
import triton.language as tl

from fla.ops.utils.op import exp, log


# HAS_SCALE is resolved at compile time from whether a `scale` value is passed
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
# autotune over warp counts; tuning results are keyed on the feature dimension D
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,
    z,
    scale,
    D: tl.constexpr,
    B: tl.constexpr,
    HAS_SCALE: tl.constexpr
):
    # one program per (row, block): i_n indexes the row, i_d the block of size B;
    # int64 indices keep the flat offsets from overflowing on large inputs
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    m_d = o_d < D

    # masked (out-of-bounds) lanes read -inf and contribute exp(-inf) = 0 to the sum
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))
    if HAS_SCALE:
        b_x = b_x * scale
    # numerically stable logsumexp: subtract the block max before exponentiating
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    # one partial result per (row, block); the host combines them with a second logsumexp
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
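

# A reference sketch in plain PyTorch of what the kernel computes, kept as
# documentation only. `_logsumexp_fwd_ref` is a hypothetical helper, not part
# of the module's API, and omits `scale`: the last dimension is split into
# blocks of size B, each block yields a partial logsumexp, and the partials
# are combined with a second logsumexp (as done on the host below).
def _logsumexp_fwd_ref(x: torch.Tensor, B: int) -> torch.Tensor:
    N, D = x.shape
    ND = triton.cdiv(D, B)
    # pad with -inf so the padded lanes contribute exp(-inf) = 0, as in the kernel
    x = torch.nn.functional.pad(x, (0, ND * B - D), value=-float('inf'))
    z = x.reshape(N, ND, B).logsumexp(-1)
    return z.logsumexp(-1)

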
def logsumexp_fwd(
    x,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
):
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    Args:
        x (Tensor):
            The input tensor of any shape.
        scale (Optional[float]):
            The scale applied to the input tensor before the reduction. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor.
            If `None`, the result is returned in `torch.float`. Default: `None`.

    Returns:
        Tensor: The logsumexp of the input tensor.
    """
    shape = x.shape
    x = x.view(-1, shape[-1])
    N, D = x.shape
    # block size: the next power of two covering D, capped at 64K elements
    B = min(triton.next_power_of_2(D), 64 * 1024)
    ND = triton.cdiv(D, B)

    # one fp32 partial logsumexp per (row, block)
    z = x.new_empty(N, ND, dtype=torch.float)
    logsumexp_fwd_kernel[(N, ND)](
        x=x,
        z=z,
        scale=scale,
        D=D,
        B=B
    )
    # combine the per-block partials and restore the original leading shape
    z = z.logsumexp(-1).view(*shape[:-1])
    if dtype is not None and dtype != torch.float:
        z = z.to(dtype)
    return z
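

# Usage sketch (assumes a CUDA device; with `scale`, the result matches
# `torch.logsumexp(x * scale, dim=-1)` computed in fp32):
#
#     x = torch.randn(4, 4096, device='cuda')
#     out = logsumexp_fwd(x, scale=0.5, dtype=x.dtype)
#     ref = torch.logsumexp(x.float() * 0.5, dim=-1).to(x.dtype)
#     torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)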