# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# Implements argsort based on bitonic sort.
# [What is bitonic sort?](https://en.wikipedia.org/wiki/Bitonic_sorter)

# Code adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396

import triton
import triton.language as tl

from fla.ops.utils.op import log2


@triton.jit
def _compare_and_swap(
    x,
    ids,
    flip,
    i: tl.constexpr,
    n_dims: tl.constexpr,
):
    """
    Run one compare-and-swap round of the bitonic network over the last `n_dims` axes.

    The `2**n_dims` positions are viewed as `[n_outer * 2**i, 2, 2**(n_dims - i - 1)]`,
    so each element is paired with the element `2**(n_dims - i - 1)` positions away.
    A pair is swapped when `(left > right) != flip`: with `flip == 0` the smaller value
    ends up on the left, with `flip == 1` on the right. `ids` receives exactly the same
    swaps as `x`.

    Returns:
        The updated `(x, ids)`.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
    # mask == 0 selects the left member of each pair, mask == 1 the right member
    mask = tl.arange(0, 2)[None, :, None]

    # Extract the pair members from the raw bit patterns rather than the values.
    # A float multiply-and-sum is lossy here: `inf * 0` is NaN (poisoning the partner)
    # and `-0.0 + 0.0` is `+0.0`, which makes the XOR swap below write a corrupted value
    # (e.g. flipping the sign of the swapped element). Integer multiply-and-sum is exact.
    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ix = x.to(idtype, bitcast=True)
    iy = tl.reshape(ix, shape)
    ileft = tl.broadcast_to(tl.sum(iy * (1 - mask), 1)[:, None, :], shape).to(idtype)
    iright = tl.broadcast_to(tl.sum(iy * mask, 1)[:, None, :], shape).to(idtype)
    ileft = tl.reshape(ileft, x.shape)
    iright = tl.reshape(iright, x.shape)
    # the comparison itself is done in the original dtype
    left = ileft.to(x.dtype, bitcast=True)
    right = iright.to(x.dtype, bitcast=True)

    # same pair extraction for the ids
    y_idx = tl.reshape(ids, shape)
    left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape)
    right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape)
    left_idx = tl.reshape(left_idx, x.shape).to(y_idx.dtype)
    right_idx = tl.reshape(right_idx, x.shape).to(y_idx.dtype)

    # Actual compare-and-swap: XOR-ing an element with (left ^ right) turns it into its
    # partner, so both slots of a pair are exchanged without any gather.
    # Note: with flip == 1, equal values are also "swapped", so the ids of tied
    # elements may be reordered.
    cond = (left > right) != flip
    ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
    new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids))
    return ret.to(x.dtype, bitcast=True), new_ids


@triton.jit
def _bitonic_merge(
    x,
    ids,
    stage: tl.constexpr,
    order: tl.constexpr,
    n_dims: tl.constexpr,
):
    """
    Perform one bitonic merge stage: sort each block of `2**stage` consecutive elements.

    Args:
        stage: merge stage; runs `stage` compare-and-swap rounds (`stage <= n_dims`).
        order: 2 -> blocks alternate ascending/descending (builds bitonic sequences for
            the next stage); otherwise `order` is used directly as `flip`
            (0 -> ascending, 1 -> descending).
        n_dims: log2 of the length of the sorted (last) axis.

    Returns:
        The updated `(x, ids)`.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 0..01..10..01..1... (runs of length 2**stage) then consecutive blocks of
    # 2**stage elements will be re-arranged alternatingly ascending/descending
    if order == 2:
        shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`, with pair distance halving each round
    for i in tl.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids


@triton.jit
def argsort(
    x,
    ids,
    dim: tl.constexpr = None,
    descending: tl.constexpr = tl.core.CONSTEXPR_0,
):
    """
    Sort `x` along its last dimension with a bitonic network, permuting `ids` alongside.

    Args:
        x: tensor to sort; only the last (minor) dimension is supported. Its length is
            split into `2**log2(len)` positions by the network, so it must be a power of two.
        ids: tensor reshaped together with `x` at every step, so it must have the same
            number of elements as `x`; it receives the same swaps. Presumably callers pass
            an index tensor (e.g. from `tl.arange`) to obtain an argsort — TODO confirm.
        dim: dimension to sort; `None` means the last one. Any other value must still
            refer to the last dimension.
        descending: falsy -> ascending order, truthy -> descending order.

    Returns:
        `(sorted_x, permuted_ids)`. The sort is not guaranteed to be stable.
    """
    # handle default dimension or check that it is the most minor dim
    _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps; n_dims is used as a compile-time loop bound
    n_dims: tl.constexpr = log2(x.shape[_dim])
    for i in tl.static_range(1, n_dims + 1):
        # all stages but the last build alternating (bitonic) runs; the last one applies
        # the requested direction to the whole axis
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids