medmekk (HF Staff) committed
Commit d538a8a · verified · 1 Parent(s): 827249e

Upload custom kernels
bitnet_kernel/bitnet_kernel.cu ADDED
@@ -0,0 +1,73 @@
+ #include "bitnet_kernel.h"
+ #include <torch/all.h>
+ #include <ATen/cuda/CUDAContext.h>
+
+ // Dispatch to a shape-specialized kernel instantiation. Each block computes
+ // N_block_size = 16 output elements, so grid.x = N / 16 (e.g. 3840 / 16 = 240).
+ extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
+     if (M == 1 && N == 3840 && K == 2560){
+         ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<<dim3(240, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 2560 && K == 2560){
+         ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 13824 && K == 2560){
+         ladder_int8xint2_kernel<1, 13824, 2560, 2, 8, 16><<<dim3(864, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 2560 && K == 6912){
+         ladder_int8xint2_kernel<1, 2560, 6912, 1, 8, 16><<<dim3(160, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 4800 && K == 3200){
+         ladder_int8xint2_kernel<1, 4800, 3200, 6, 8, 16><<<dim3(300, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 3200 && K == 3200){
+         ladder_int8xint2_kernel<1, 3200, 3200, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 20480 && K == 3200){
+         ladder_int8xint2_kernel<1, 20480, 3200, 2, 8, 16><<<dim3(1280, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 3200 && K == 10240){
+         ladder_int8xint2_kernel<1, 3200, 10240, 1, 8, 16><<<dim3(200, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 5120 && K == 27648){
+         ladder_int8xint2_kernel<1, 5120, 27648, 1, 8, 16><<<dim3(320, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else if (M == 1 && N == 55296 && K == 5120){
+         ladder_int8xint2_kernel<1, 55296, 5120, 1, 8, 16><<<dim3(3456, 1, 1), dim3(8, 16, 1), 0, stream>>>(input0, input1, output0, s, ws);
+     }
+     else{
+         std::cerr << "no ladder gemm kernel specialization for: M " << M << ", N " << N << ", K " << K << std::endl;
+     }
+ }
+
+ torch::Tensor bitlinear_int8xint2_cpp(torch::Tensor input0, torch::Tensor input1, torch::Tensor s, torch::Tensor ws) {
+     // The output keeps input0's leading dimensions; its last dimension is N
+     auto out_shape = input0.sizes().vec();
+     out_shape.back() = input1.size(0);
+
+     // Calculate M, N, K (K counts unpacked int2 weights: 4 per packed int8 byte)
+     int M = input0.size(0);
+     if (out_shape.size() == 3) {
+         M *= input0.size(1);
+     }
+     int N = input1.size(0);
+     int K = input1.size(1) * 4;
+
+     // Create output tensor
+     auto output = torch::zeros(out_shape, torch::dtype(torch::kBFloat16).device(input0.device()));
+
+     // Launch on the current CUDA stream
+     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+     // Call kernel
+     bitlinear_int8xint2(
+         reinterpret_cast<int8_t*>(input0.data_ptr()),
+         reinterpret_cast<int8_t*>(input1.data_ptr()),
+         reinterpret_cast<__nv_bfloat16*>(output.data_ptr()),
+         reinterpret_cast<__nv_bfloat16*>(s.data_ptr()),
+         reinterpret_cast<__nv_bfloat16*>(ws.data_ptr()),
+         M, N, K, stream
+     );
+
+     return output;
+ }
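For context, a wrapper like bitlinear_int8xint2_cpp is typically exposed to Python through a pybind11 extension module. The sketch below is illustrative only and not part of this commit; the op name it registers is an assumption:

    #include <torch/extension.h>

    // Hypothetical binding (not in this commit): exposes the C++ wrapper to Python.
    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        m.def("bitlinear_int8xint2", &bitlinear_int8xint2_cpp,
              "int8 activation x int2 weight BitLinear GEMV (CUDA)");
    }

From Python the op would then be called with the int8 activations, the int2 weights packed four per byte, and the activation and weight scale tensors s and ws.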
bitnet_kernel/bitnet_kernel.h ADDED
@@ -0,0 +1,90 @@
+ #include <cuda_runtime.h>
+ #include <math_constants.h>
+ #include <math.h>
+ #include <mma.h>
+ #include <iostream>
+ #include <cuda.h>
+ #include <cuda_fp16.h>
+ #include <cuda_bf16.h>
+
+
+ #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || (__CUDACC_VER_MAJOR__ > 11))
+ #define TVM_ENABLE_L2_PREFETCH 1
+ #else
+ #define TVM_ENABLE_L2_PREFETCH 0
+ #endif
+
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 800
+ #define TVM_ENABLE_EFFICIENT_SMEM_PTR_CAST 1
+ #else
+ #define TVM_ENABLE_EFFICIENT_SMEM_PTR_CAST 0
+ #endif
+
+ template <typename T1, typename T2>
+ __device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16)
+ {
+     // unpack 16 int2b_t (one uint32) into 16 int8b_t (four uint32s)
+     uint *i8s = reinterpret_cast<uint *>(_i8s);
+
+     // i2s = {e0, e4, e8, e12, e1, e5, e9, e13, e2, e6, e10, e14, e3, e7, e11, e15}
+     uint const i2s = *_i2s;
+
+     static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010: lop3 truth table for (a & b) | c
+     static constexpr uint BOTTOM_MASK = 0x03030303;      // select the low 2 bits of each byte
+     static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000;
+
+     #pragma unroll
+     for (int i = 0; i < (N / 4); i++)
+     {
+         // lop3 computes (input & BOTTOM_MASK) | MAGIC_NUM in a single instruction
+         asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                      : "=r"(i8s[i])
+                      : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
+         // shift the unsigned codes {0..3} down to the signed range {-2..1}
+         i8s[i] = __vsubss4(i8s[i], 0x02020202);
+     }
+ }
+
+ template <int M, int N, int K, int ws_num, int K_block_size, int N_block_size>
+ __global__ void __launch_bounds__(128) ladder_int8xint2_kernel(int8_t* __restrict__ A, int8_t* __restrict__ B, __nv_bfloat16* __restrict__ dtype_transform, __nv_bfloat16* __restrict__ s, __nv_bfloat16* __restrict__ ws) {
+     constexpr int K_per_loop = 16;
+     constexpr int wmma_K = 32;
+     constexpr int wmma_N = 16;
+     int in_thread_C_local[1];
+     signed char A_local[K_per_loop];
+     int B_reshape_local[1];
+     signed char B_decode_local[K_per_loop];
+     int red_buf0[1];
+     in_thread_C_local[0] = 0;
+     #pragma unroll
+     for (int k_0 = 0; k_0 < K/(K_per_loop * K_block_size); ++k_0) {
+         // each thread loads 16 int8 activations as a single int4 vector load
+         *(int4*)(A_local + 0) = *(int4*)(A + ((k_0 * K_per_loop * K_block_size) + (((int)threadIdx.x) * K_per_loop)));
+         // gather one int32 holding 16 packed int2 weights from the ladder-permuted layout
+         B_reshape_local[0] = *(int*)(B +
+             (((int)blockIdx.x) * N_block_size * K / 4) +
+             (k_0 * K_block_size * K_per_loop * wmma_N / 4) +
+             ((((int)threadIdx.x) >> 1) * wmma_K * wmma_N / 4) +
+             ((((int)threadIdx.y) >> 3) * (wmma_K * wmma_N / 2) / 4) +
+             ((((int)threadIdx.x) & 1) * (wmma_K * wmma_N / 4) / 4) +
+             ((((int)threadIdx.y) & 7) * (wmma_K / 2) / 4)
+         );
+         decode_i2s_to_i8s(B_reshape_local, B_decode_local, 16);
+         #pragma unroll
+         for (int k_2_0 = 0; k_2_0 < 4; ++k_2_0) {
+             // 4-way int8 dot product with int32 accumulation
+             in_thread_C_local[0] = __dp4a(*(int *)&A_local[((k_2_0 * 4))], *(int *)&B_decode_local[((k_2_0 * 4))], in_thread_C_local[0]);
+         }
+     }
+     red_buf0[0] = in_thread_C_local[0];
+     // reduce the K_block_size partial sums along threadIdx.x with warp shuffles
+     #pragma unroll
+     for (int offset = K_block_size/2; offset > 0; offset /= 2) {
+         red_buf0[0] += __shfl_down_sync(__activemask(), red_buf0[0], offset, K_block_size);
+     }
+     int out_idx = ((((int)blockIdx.x) * N_block_size) + ((int)threadIdx.y));
+     int ws_idx = out_idx / (N / ws_num);
+     // lane 0 applies the activation scale s and the per-group weight scale ws
+     if (threadIdx.x == 0)
+         dtype_transform[out_idx] = (__nv_bfloat16)(((float)red_buf0[0])/(float)s[0]*(float)ws[ws_idx]);
+ }
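To make the packed weight layout concrete, the host-side sketch below reproduces what decode_i2s_to_i8s computes for one 32-bit word with N = 16. It is a reference for illustration only, derived from the bit manipulation above; the saturating __vsubss4 reduces to a plain subtraction here because the masked codes lie in {0..3}:

    #include <cstdint>
    #include <cstdio>

    // Reference decode: output word i, byte j reads the 2-bit field at bit (2*i + 8*j)
    // of the packed word and maps the unsigned code {0..3} to the signed value {-2..1}.
    static void decode_i2s_to_i8s_ref(uint32_t i2s, int8_t out[16]) {
        for (int i = 0; i < 4; ++i)       // one iteration per lop3 in the device code
            for (int j = 0; j < 4; ++j)   // one byte per output uint32
                out[4 * i + j] = (int8_t)((int)((i2s >> (2 * i + 8 * j)) & 3u) - 2);
    }

    int main() {
        int8_t out[16];
        decode_i2s_to_i8s_ref(0x000000E4u, out);  // byte 0 = 0b11'10'01'00 -> codes 3,2,1,0
        // out[0], out[4], out[8], out[12] come from byte 0 and print -2, -1, 0, 1;
        // every other position decodes code 0, i.e. -2.
        for (int k = 0; k < 16; ++k) printf("%d ", out[k]);
        printf("\n");
        return 0;
    }

This also makes the interleaving in the layout comment visible: consecutive bit pairs of a byte feed every fourth output element, matching the {e0, e4, e8, e12, ...} ordering noted above.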