|  | #undef CUB_WRAPPED_NAMESPACE | 
					
						
						|  | #define CUB_WRAPPED_NAMESPACE megablocks | 
					
						
						|  |  | 
					
						
						|  | #include "new_sort.h" | 
					
						
						|  | #include <cstdint> | 
					
						
						|  | #include <cub/cub.cuh> | 
					
						
						|  | #include <c10/cuda/CUDAStream.h> | 
					
						
						|  |  | 
					
						
						|  | #define CUDA_CALL(code)					    \ | 
					
						
						|  | do {                                                      \ | 
					
						
						|  | cudaError_t status = code;                              \ | 
					
						
						|  | std::string err = cudaGetErrorString(status);           \ | 
					
						
						|  | TORCH_CHECK(status == cudaSuccess, err);		    \ | 
					
						
						|  | } while (0) | 
					
						
						|  |  | 
					
						
						|  | namespace megablocks { | 
					
						
						|  |  | 
					
						
						|  | template <typename T> | 
					
						
						|  | void cub_radix_sort(torch::Tensor x, | 
					
						
						|  | int end_bit, | 
					
						
						|  | torch::Tensor x_out, | 
					
						
						|  | torch::Tensor iota_out) { | 
					
						
						|  |  | 
					
						
						|  | torch::Tensor iota = torch::arange(0, x.numel(), x.options()); | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | size_t scratchpad_bytes = 0; | 
					
						
						|  | CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, | 
					
						
						|  | scratchpad_bytes, | 
					
						
						|  | x.data_ptr<T>(), | 
					
						
						|  | x_out.data_ptr<T>(), | 
					
						
						|  | iota.data_ptr<T>(), | 
					
						
						|  | iota_out.data_ptr<T>(), | 
					
						
						|  | x.numel(), | 
					
						
						|  | 0, | 
					
						
						|  | end_bit, | 
					
						
						|  | c10::cuda::getCurrentCUDAStream())); | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | auto options = torch::TensorOptions() | 
					
						
						|  | .dtype(torch::kInt8) | 
					
						
						|  | .device(x.device()); | 
					
						
						|  | torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options); | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | CUDA_CALL(cub::DeviceRadixSort::SortPairs(scratchpad.data_ptr(), | 
					
						
						|  | scratchpad_bytes, | 
					
						
						|  | x.data_ptr<T>(), | 
					
						
						|  | x_out.data_ptr<T>(), | 
					
						
						|  | iota.data_ptr<T>(), | 
					
						
						|  | iota_out.data_ptr<T>(), | 
					
						
						|  | x.numel(), | 
					
						
						|  | 0, | 
					
						
						|  | end_bit, | 
					
						
						|  | c10::cuda::getCurrentCUDAStream())); | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | void sort(torch::Tensor x, | 
					
						
						|  | int end_bit, | 
					
						
						|  | torch::Tensor x_out, | 
					
						
						|  | torch::Tensor iota_out) { | 
					
						
						|  | TORCH_CHECK(x.is_cuda()); | 
					
						
						|  | TORCH_CHECK(x.ndimension() == 1); | 
					
						
						|  | TORCH_CHECK(x.scalar_type() == torch::kInt16 || | 
					
						
						|  | x.scalar_type() == torch::kInt32 || | 
					
						
						|  | x.scalar_type() == torch::kInt64); | 
					
						
						|  | TORCH_CHECK(x_out.is_cuda()); | 
					
						
						|  | TORCH_CHECK(x_out.ndimension() == 1); | 
					
						
						|  | TORCH_CHECK(x_out.scalar_type() == x.scalar_type()); | 
					
						
						|  | TORCH_CHECK(iota_out.is_cuda()); | 
					
						
						|  | TORCH_CHECK(iota_out.ndimension() == 1); | 
					
						
						|  | TORCH_CHECK(iota_out.scalar_type() == x.scalar_type()); | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if (x_out.numel() == 0) return; | 
					
						
						|  |  | 
					
						
						|  | switch (x.scalar_type()) { | 
					
						
						|  | case torch::kInt16: | 
					
						
						|  | return cub_radix_sort<short>(x, end_bit, x_out, iota_out); | 
					
						
						|  | case torch::kInt32: | 
					
						
						|  | return cub_radix_sort<int>(x, end_bit, x_out, iota_out); | 
					
						
						|  | } | 
					
						
						|  | TORCH_CHECK(x.scalar_type() == torch::kInt64); | 
					
						
						|  | return cub_radix_sort<long>(x, end_bit, x_out, iota_out); | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | #undef CUDA_CALL | 
					
						
						|  | #undef CUB_WRAPPED_NAMESPACE |