# -*- coding: utf-8 -*- | |
# Copyright (c) 2024, Songlin Yang, Yu Zhang | |
import os | |
import triton | |
import triton.language as tl | |
import triton.language.extra.libdevice as tldevice | |
from fla.utils import is_gather_supported | |
# Elementwise math primitives shared by the Triton kernels that import them.
# The environment variable FLA_USE_FAST_OPS is read once, at import time:
# '1' selects libdevice's fast intrinsics; any other value (or unset) selects
# the standard triton.language operations.
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    # `div` is called from inside @triton.jit kernels, exactly like
    # `tldevice.fast_dividef` in the branch above, so the fallback must itself
    # be a JIT function: kernels cannot call plain Python functions.
    @triton.jit
    def div_normal(x, y):
        """Return ``x / y`` using Triton's regular ``/`` operator."""
        return x / y

    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2
@triton.jit
def safe_exp(x):
    """Exponentiate only the non-positive entries of ``x``.

    Entries with ``x <= 0`` yield ``exp(x)``. Every other entry (positive
    values, and also NaN, because ``NaN <= 0`` is false) is replaced by
    ``-inf`` before exponentiation and therefore yields 0. Positive inputs
    can never reach the exponential, so the result cannot overflow.

    Uses the module-level ``exp``, so it follows the FLA_USE_FAST_OPS choice.
    Must be a @triton.jit function so that kernels are able to call it.
    """
    return exp(tl.where(x <= 0, x, float('-inf')))
# Export `gather` unconditionally so kernels can always resolve the name.
# `is_gather_supported` comes from fla.utils; presumably it reports whether the
# installed Triton provides `tl.gather` — confirm against fla.utils.
if not is_gather_supported:
    @triton.jit
    def gather(src, index, axis):
        """Placeholder for ``tl.gather`` on Triton builds without it.

        Mirrors the ``tl.gather(src, index, axis)`` signature but performs no
        gather and returns nothing. Any code path that actually uses the result
        must be guarded so it only runs when ``is_gather_supported`` is true.
        """
        # Decorated with @triton.jit (and given explicit parameters instead of
        # *args/**kwargs) so that it is callable from within JIT kernels.
        pass
else:
    gather = tl.gather