zaydzuhri's picture
Add files using upload-large-folder tool
f72219a verified
raw
history blame
1.9 kB
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2024, Songlin Yang, Yu Zhang
from typing import Optional
import torch
import triton
import triton.language as tl
from fla.ops.utils.op import exp, log
@triton.heuristics({
    'HAS_SCALE': lambda args: args['scale'] is not None
})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=['D']
)
@triton.jit
def logsumexp_fwd_kernel(
    x,
    z,
    scale,
    D: tl.constexpr,
    B: tl.constexpr,
    HAS_SCALE: tl.constexpr
):
    # Partial logsumexp of one block of one row.
    #
    # x:         pointer to the input, addressed as rows of length D
    #            (element (n, d) is read from `x + n * D + d`).
    # z:         pointer to the output, addressed as rows of length cdiv(D, B);
    #            program (i_n, i_d) writes a single value to `z + i_n * cdiv(D, B) + i_d`.
    # scale:     optional multiplier applied to the input before the reduction.
    # D:         length of the reduced dimension.
    # B:         number of elements handled by one program.
    # HAS_SCALE: set by the heuristic above; True iff `scale` is not None.
    i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)
    o_d = i_d * B + tl.arange(0, B)
    # lanes past the end of the row must not contribute to the reduction
    m_d = o_d < D

    # Out-of-range lanes are filled with -inf so that exp() maps them to 0.
    # The block is upcast to float32 so the max/exp/sum/log chain does not run
    # in a lower input precision.
    b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf')).to(tl.float32)
    if HAS_SCALE:
        # Re-apply the mask after scaling: -inf * 0 is NaN and -inf * negative
        # is +inf, either of which would corrupt the max and the sum below.
        b_x = tl.where(m_d, b_x * scale, -float('inf'))
    # subtract the block maximum before exponentiating for numerical stability
    b_m = tl.max(b_x, 0)
    b_z = log(tl.sum(exp(b_x - b_m), 0)) + b_m
    tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)
def logsumexp_fwd(
    x: torch.Tensor,
    scale: Optional[float] = None,
    dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    r"""
    Compute the logsumexp of the input tensor over the last dimension.

    The last dimension is split into blocks of at most 64K elements; the kernel
    writes one partial logsumexp per (row, block) and the partial results are
    then combined with `Tensor.logsumexp`.

    Args:
        x (Tensor):
            The input tensor with at least one dimension. Non-contiguous inputs
            are copied into a contiguous buffer before the kernel is launched.
        scale (Optional[float]):
            The scale applied to the input tensor. Default: `None`.
        dtype (Optional[torch.dtype]):
            The data type of the output tensor. Default: `None`,
            in which case the result is `torch.float`.

    Returns:
        Tensor: The logsumexp of the input tensor, of shape `x.shape[:-1]`.
        If `x` has no elements, the result is filled with `-inf`.
    """
    shape = x.shape
    D = shape[-1]

    if x.numel() == 0:
        # Nothing to launch: either there are no rows, or the reduction runs
        # over zero elements, whose logsumexp is log(0) = -inf.
        z = x.new_full(shape[:-1], float('-inf'), dtype=torch.float)
    else:
        # The kernel addresses rows as `i_n * D`, so it needs a dense row-major
        # (N, D) buffer. `view` alone can succeed on strided inputs (e.g. a
        # column slice) and would silently make the kernel read wrong memory.
        x = x.reshape(-1, D).contiguous()
        N = x.shape[0]
        B = min(triton.next_power_of_2(D), 64 * 1024)
        ND = triton.cdiv(D, B)
        # one partial result per (row, block)
        z = x.new_empty(N, ND, dtype=torch.float)
        logsumexp_fwd_kernel[(N, ND)](
            x=x,
            z=z,
            scale=scale,
            D=D,
            B=B
        )
        # Combine the per-block partials. The size object is passed directly so
        # that 1-D inputs (empty leading shape) produce a 0-dim tensor.
        z = z.logsumexp(-1).view(shape[:-1])

    if dtype is not None and dtype != torch.float:
        z = z.to(dtype)
    return z