Upload custom kernels
- bitnet_kernel/bitnet_kernels.cu +72 -0
- bitnet_kernel/bitnet_kernels.h +83 -0
- build.toml +15 -0
- flake.nix +13 -0
- torch-ext/bitnet_kernel/__init__.py +4 -0
- torch-ext/torch_binding.cpp +11 -0
- torch-ext/torch_binding.h +5 -0
bitnet_kernel/bitnet_kernels.cu
ADDED
@@ -0,0 +1,72 @@
#include "bitnet_kernels.h"
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>  // for at::cuda::getCurrentCUDAStream()

extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
    if (M == 1 && N == 3840 && K == 2560){
        ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 2560 && K == 2560){
        ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 13824 && K == 2560){
        ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<<dim3(864, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 2560 && K == 6912){
        ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 4800 && K == 3200){
        ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<<dim3(300, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 3200 && K == 3200){
        ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 20480 && K == 3200){
        ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<<dim3(1280, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 3200 && K == 10240){
        ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 5120 && K == 27648){
        ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<<dim3(320, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else if (M == 1 && N == 55296 && K == 5120){
        ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<<dim3(3456, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
    }
    else{
        std::cout << "unsupported ladder gemm kernel shape: M " << M << ", N " << N << ", K " << K << std::endl;
    }
}

torch::Tensor bitlinear_int8xint2_cpp(torch::Tensor input0, torch::Tensor input1, torch::Tensor s, torch::Tensor ws) {
    // Output keeps the activation's leading dims; the last dim becomes the number of output features
    auto out_shape = input0.sizes().vec();
    out_shape.back() = input1.size(0);

    // Calculate M, N, K (the weight packs 4 int2 values per int8 byte, hence K = cols * 4)
    int M = input0.size(0);
    if (out_shape.size() == 3) {
        M *= input0.size(1);
    }
    int N = input1.size(0);
    int K = input1.size(1) * 4;

    // Create output tensor
    auto output = torch::zeros(out_shape, torch::dtype(torch::kBFloat16).device(input0.device()));

    // Get CUDA stream
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // Call kernel
    bitlinear_int8xint2(
        reinterpret_cast<int8_t*>(input0.data_ptr()),
        reinterpret_cast<int8_t*>(input1.data_ptr()),
        reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),
        reinterpret_cast<__nv_bfloat16*>(s.data_ptr()),
        reinterpret_cast<__nv_bfloat16*>(ws.data_ptr()),
        M, N, K, stream
    );

    return output;
}
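
Note on the shape math above: the wrapper flattens the activation's leading dimensions into M, reads N from the packed weight's rows, and recovers K from its columns (four int2 weights per int8 byte), while the dispatcher launches one block of 8x16 threads per 16 output columns, i.e. dim3(N / 16, 1, 1). A small illustrative Python sketch of that arithmetic follows; the helper name and the example shapes are hypothetical, not part of the extension.

# Hypothetical helper mirroring the wrapper's M/N/K derivation and the dispatcher's launch geometry.
def describe_launch(activation_shape, packed_weight_shape, n_block_size=16, k_block_size=8):
    *leading, _ = activation_shape           # leading dims (batch, plus seq for 3-D inputs)
    M = 1
    for d in leading:
        M *= d
    N, packed_cols = packed_weight_shape     # N output features, K/4 packed bytes per row
    K = packed_cols * 4                      # four 2-bit weights per byte
    grid = (N // n_block_size, 1, 1)         # one block per 16 output columns
    block = (k_block_size, n_block_size, 1)  # dim3(8, 16, 1), as in the launches above
    return M, N, K, grid, block

# e.g. a (1, 2560) int8 activation with a (3840, 640) packed weight:
# (1, 3840, 2560, (240, 1, 1), (8, 16, 1)) -- matching the first dispatch case
print(describe_launch((1, 2560), (3840, 640)))
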
bitnet_kernel/bitnet_kernels.h
ADDED
@@ -0,0 +1,83 @@
#include <cuda_runtime.h>
#include <math_constants.h>
#include <math.h>
#include <mma.h>
#include <iostream>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 1
#else
#define TVM_ENBALE_EFFICIENT_SMEM_PTR_CAST 0
#endif

template <typename T1, typename T2>
__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16)
{
    // convert 16 packed int2 values into 16 int8 values (4 int32 words)
    uint *i8s = reinterpret_cast<uint *>(_i8s);

    // i2s = {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15}
    uint const i2s = *_i2s;

    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;  // 0b11101010
    static constexpr uint BOTTOM_MASK = 0x03030303;       // select the low 2 bits of every byte
    static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000;

#pragma unroll
    for (int i = 0; i < (N / 4); i++)
    {
        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                     : "=r"(i8s[i])
                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
        i8s[i] = __vsubss4(i8s[i], 0x02020202);  // shift the 2-bit codes back to signed values
    }
}

template <int M, int N, int K, int ws_num, int K_block_size, int N_block_size>
__global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) {
    constexpr int K_per_loop = 16;
    constexpr int wmma_K = 32;
    constexpr int wmma_N = 16;
    int in_thread_C_local[1];
    signed char A_local[K_per_loop];
    int B_reshape_local[1];
    signed char B_decode_local[K_per_loop];
    int red_buf0[1];
    in_thread_C_local[0] = 0;
#pragma unroll
    for (int k_0 = 0; k_0 < K / (K_per_loop * K_block_size); ++k_0) {
        // Each thread loads 16 int8 activations and one int32 of packed weights (16 int2 values).
        *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop)));
        B_reshape_local[0] = *(int*)(B +
            (((int)blockIdx.x) * N_block_size * K / 4) +
            (k_0 * K_block_size * K_per_loop * wmma_N / 4) +
            ((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) +
            ((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) +
            ((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) +
            ((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4)
        );
        decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16);
#pragma unroll
        for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) {
            in_thread_C_local[0] = __dp4a(*(int *)&A_local[(k_2_0 * 4)], *(int *)&B_decode_local[(k_2_0 * 4)], in_thread_C_local[0]);
        }
    }
    red_buf0[0] = in_thread_C_local[0];
    // Tree-reduce the partial dot products across the K_block_size threads along x.
#pragma unroll
    for (int offset = K_block_size / 2; offset > 0; offset /= 2) {
        red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size);
    }
    int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y));
    int ws_idx = out_idx / (N / ws_num);
    if (threadIdx.x == 0)
        dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0]) / (float)s[0] * (float)ws[ws_idx]);
}
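
As a reading aid, here is a hedged NumPy reference for what the kernel computes; the helper names `unpack_int2` and `bitlinear_reference` are illustrative, not part of this extension. Each byte of B holds four 2-bit codes that are shifted back by 2 into signed values, the int8 activations are accumulated against the decoded weights with integer dot products, and each output column is divided by the activation scale s and multiplied by its group's weight scale ws. The sketch deliberately ignores the interleaved weight layout the kernel uses for coalesced loads.

import numpy as np

def unpack_int2(packed):
    # Illustrative unpack: four 2-bit fields per byte, shifted back by 2
    # (mirroring __vsubss4(..., 0x02020202)); ignores the GPU's interleaved layout.
    packed = packed.astype(np.uint8)
    vals = np.stack([(packed >> (2 * i)) & 0x3 for i in range(4)], axis=-1)
    return vals.reshape(*packed.shape[:-1], -1).astype(np.int8) - 2

def bitlinear_reference(activations_int8, packed_weights, s, ws):
    # Reference for a single row (M == 1): int8 x int2 matmul, then per-column rescaling.
    W = unpack_int2(packed_weights)                                # (N, K) decoded weights
    acc = W.astype(np.int32) @ activations_int8.astype(np.int32)  # (N,) integer dot products
    N = W.shape[0]
    group = N // len(ws)                                           # ws_num groups of output columns
    scale = np.array([ws[i // group] for i in range(N)], dtype=np.float32)
    return acc.astype(np.float32) / np.float32(s) * scale          # the kernel stores this as bfloat16
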
build.toml
ADDED
@@ -0,0 +1,15 @@
[general]
name = "bitnet_kernel"

[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h"
]

[kernel.bitnet_kernel]
src = [
  "bitnet_kernel/bitnet_kernels.cu",
  "bitnet_kernel/bitnet_kernels.h"
]
depends = ["torch"]
flake.nix
ADDED
@@ -0,0 +1,13 @@
{
  description = "Flake for Torch kernel extension";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs = { self, kernel-builder, }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
torch-ext/bitnet_kernel/__init__.py
ADDED
@@ -0,0 +1,4 @@
from ._ops import ops

def bitnet_int8xint2_linear(input0, input1, s, ws):
    return ops.bitlinear_int8xint2_cpp(input0, input1, s, ws)
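
A quick usage sketch, assuming the built package is importable as `bitnet_kernel` and the input hits one of the dispatched shapes; the random values below only illustrate shapes and dtypes, since real inputs come from a quantized BitNet checkpoint's packed weights and scales.

import torch
from bitnet_kernel import bitnet_int8xint2_linear  # assumes the built extension is on the path

# M=1, N=2560, K=2560 -- one of the shapes the dispatcher supports.
x = torch.randint(-128, 128, (1, 2560), dtype=torch.int8, device="cuda")    # int8 activations
w = torch.randint(-128, 128, (2560, 640), dtype=torch.int8, device="cuda")  # packed int2 weights, 4 per byte
s = torch.ones(1, dtype=torch.bfloat16, device="cuda")                      # activation scale
ws = torch.ones(1, dtype=torch.bfloat16, device="cuda")                     # weight scale (ws_num = 1 here)

y = bitnet_int8xint2_linear(x, w, s, ws)  # -> (1, 2560) bfloat16 tensor
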
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,11 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("bitlinear_int8xint2_cpp(Tensor input0, Tensor input1, Tensor s, Tensor ws) -> Tensor");
  ops.impl("bitlinear_int8xint2_cpp", torch::kCUDA, &bitlinear_int8xint2_cpp);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,5 @@
#pragma once

#include <torch/torch.h>

torch::Tensor bitlinear_int8xint2_cpp(torch::Tensor input0, torch::Tensor input1, torch::Tensor s, torch::Tensor ws);