zaydzuhri's picture
Add files using upload-large-folder tool
f72219a verified
raw
history blame
721 Bytes
# -*- coding: utf-8 -*-
# Copyright (c) 2024, Songlin Yang, Yu Zhang
import os
import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice
from fla.utils import is_gather_supported
# Select the elementwise math primitives (div, exp, log, log2) used by the
# kernels that import them from here. The choice is made once, at import time:
# setting the environment variable FLA_USE_FAST_OPS to exactly '1' binds the
# libdevice ``fast_*`` variants; any other value, or leaving it unset, binds
# plain division and the standard ``triton.language`` ops.
if os.environ.get('FLA_USE_FAST_OPS', '0') != '1':
    @triton.jit
    def div_normal(x, y):
        """Return ``x / y`` using ordinary Triton division."""
        return x / y

    div, exp, log, log2 = div_normal, tl.exp, tl.log, tl.log2
else:
    # NOTE(review): the libdevice fast_* functions are presumably
    # reduced-precision approximations — confirm accuracy is acceptable
    # for the kernels before enabling FLA_USE_FAST_OPS.
    div, exp, log, log2 = (
        tldevice.fast_dividef,
        tldevice.fast_expf,
        tldevice.fast_logf,
        tldevice.fast_log2f,
    )
@triton.jit
def safe_exp(x):
    """Elementwise exponential that only lets non-positive inputs through.

    Entries with ``x <= 0`` yield ``exp(x)``. Every other entry is replaced by
    ``-inf`` before exponentiation and therefore yields ``exp(-inf)``, i.e. 0.
    This covers positive values and also NaN, because ``NaN <= 0`` is false.
    Uses the module-level ``exp`` alias, so it follows the FLA_USE_FAST_OPS
    selection.
    """
    masked = tl.where(x <= 0, x, float('-inf'))
    return exp(masked)
# Export ``gather``: the real ``tl.gather`` when available, otherwise a
# placeholder so that ``from ... import gather`` still succeeds.
if not is_gather_supported:
    def gather(*args, **kwargs):
        """Placeholder for ``tl.gather`` when ``is_gather_supported`` is false.

        The name exists only so that modules importing ``gather`` from here
        still load. Presumably this is a Triton build that lacks
        ``tl.gather``; confirm against ``fla.utils.is_gather_supported``.
        It must never actually be reached. It raises instead of silently
        returning ``None``, which no caller could use as a gather result, so
        a path that does reach it fails with an explicit message rather than
        a confusing downstream error.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "gather requires tl.gather, which is not supported by the "
            "installed Triton version"
        )
else:
    gather = tl.gather