danieldk (HF Staff) committed
Commit 5a84343 · 0 Parent(s)

Add Punica sgmv kernels
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ license: apache-2.0
+ tags:
+ - kernel
+ ---
+
+ ## Punica sgmv
+
+ [Punica](https://github.com/punica-ai/punica) sgmv kernels with modifications
+ from [Lorax](https://github.com/predibase/lorax).
+
bgmv/bgmv_all.cu ADDED
@@ -0,0 +1,5 @@
+ #include "bgmv_config.h"
+ #include "bgmv_impl.cuh"
+
+ FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half)
+ FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16)
bgmv/bgmv_config.h ADDED
@@ -0,0 +1,88 @@
+ #pragma once
+
+ template <int feat_in, int feat_out, typename T>
+ void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X,
+ T **__restrict__ W,
+ const int64_t *__restrict__ indicies, int64_t y_offset,
+ int64_t full_y_size, int64_t batch_size,
+ int64_t layer_idx, float scale);
+
+ // clang-format off
+
+ #define FOR_BGMV_WIDE(f, T, narrow) \
+ f(T, narrow, 256) \
+ f(T, narrow, 512) \
+ f(T, narrow, 640) \
+ f(T, narrow, 768) \
+ f(T, narrow, 1024) \
+ f(T, narrow, 1152) \
+ f(T, narrow, 1280) \
+ f(T, narrow, 1536) \
+ f(T, narrow, 1728) \
+ f(T, narrow, 1792) \
+ f(T, narrow, 2048) \
+ f(T, narrow, 2304) \
+ f(T, narrow, 2560) \
+ f(T, narrow, 2752) \
+ f(T, narrow, 2816) \
+ f(T, narrow, 3072) \
+ f(T, narrow, 3456) \
+ f(T, narrow, 3584) \
+ f(T, narrow, 4096) \
+ f(T, narrow, 4480) \
+ f(T, narrow, 4608) \
+ f(T, narrow, 5120) \
+ f(T, narrow, 5504) \
+ f(T, narrow, 5632) \
+ f(T, narrow, 6144) \
+ f(T, narrow, 6848) \
+ f(T, narrow, 6912) \
+ f(T, narrow, 7168) \
+ f(T, narrow, 7680) \
+ f(T, narrow, 8192) \
+ f(T, narrow, 8960) \
+ f(T, narrow, 9216) \
+ f(T, narrow, 9472) \
+ f(T, narrow, 10240) \
+ f(T, narrow, 11008) \
+ f(T, narrow, 12288) \
+ f(T, narrow, 13696) \
+ f(T, narrow, 13824) \
+ f(T, narrow, 14336) \
+ f(T, narrow, 15360) \
+ f(T, narrow, 16384) \
+ f(T, narrow, 17920) \
+ f(T, narrow, 18944) \
+ f(T, narrow, 20480) \
+ f(T, narrow, 22016) \
+ f(T, narrow, 24576) \
+ f(T, narrow, 27392) \
+ f(T, narrow, 27648) \
+ f(T, narrow, 28672) \
+ f(T, narrow, 32000) \
+ f(T, narrow, 32256) \
+ f(T, narrow, 32512) \
+ f(T, narrow, 32768) \
+ f(T, narrow, 33024) \
+ f(T, narrow, 35840) \
+ f(T, narrow, 36864) \
+ f(T, narrow, 43264) \
+ f(T, narrow, 49152) \
+ f(T, narrow, 64000) \
+ f(T, narrow, 64256) \
+ f(T, narrow, 64512) \
+ f(T, narrow, 102400) \
+ f(T, narrow, 102656) \
+ f(T, narrow, 102912) \
+ f(T, narrow, 128000) \
+ f(T, narrow, 128256) \
+ f(T, narrow, 128512) \
+
+ #define FOR_BGMV_WIDE_NARROW(f, T) \
+ FOR_BGMV_WIDE(f, T, 8) \
+ FOR_BGMV_WIDE(f, T, 16) \
+ FOR_BGMV_WIDE(f, T, 32) \
+ FOR_BGMV_WIDE(f, T, 64) \
+ FOR_BGMV_WIDE(f, T, 128)
+
+ // clang-format on
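Note: `FOR_BGMV_WIDE_NARROW` is applied to `INST_BGMV_TWOSIDE` (defined in `bgmv/bgmv_impl.cuh` below) from `bgmv/bgmv_all.cu`, emitting explicit template instantiations of `bgmv_kernel` for every (narrow, wide) pair listed above. A hand-expanded sketch of a single pair, narrow = 16 and wide = 4096, with `T = nv_half`:

```cuda
// Expansion of INST_BGMV_TWOSIDE(nv_half, 16, 4096): both the "expand"
// direction (16 -> 4096) and the "shrink" direction (4096 -> 16) are
// instantiated for the same rank/hidden-size pair.
template void bgmv_kernel<16, 4096>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    nv_half **__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t layer_idx, float scale);

template void bgmv_kernel<4096, 16>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    nv_half **__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t layer_idx, float scale);
```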
bgmv/bgmv_impl.cuh ADDED
@@ -0,0 +1,296 @@
1
+ #pragma once
2
+
3
+ #include <ATen/cuda/CUDAContext.h>
4
+ #include <cooperative_groups.h>
5
+ #include <cuda/pipeline>
6
+ #include <cuda_runtime.h>
7
+ #include <iostream>
8
+ #include <stdio.h>
9
+
10
+ #include "flashinfer/vec_dtypes.cuh"
11
+
12
+ namespace cg = cooperative_groups;
13
+
14
+ // nthrs = (32, 4)
15
+ template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
16
+ size_t W_copy_size, int tx, int ty, int tz, typename T>
17
+ __global__ void
18
+ bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X,
19
+ T** __restrict__ W,
20
+ const int64_t* __restrict__ indicies, int64_t y_offset,
21
+ int64_t full_y_size, int64_t layer_idx,
22
+ float scale) {
23
+ size_t batch_idx = blockIdx.y;
24
+ int64_t idx = indicies[batch_idx];
25
+ if (idx < 0) {
26
+ return;
27
+ }
28
+
29
+ auto block = cg::this_thread_block();
30
+ size_t j = blockIdx.x;
31
+ constexpr size_t num_pipeline_stages = 2;
32
+ constexpr size_t tile_size = tx * ty * vec_size;
33
+ __shared__ T W_shared[num_pipeline_stages * tile_size];
34
+ __shared__ T X_shared[num_pipeline_stages * tile_size];
35
+ __shared__ float y_warpwise[ty];
36
+
37
+ size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
38
+ size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
39
+ auto pipe = cuda::make_pipeline();
40
+
41
+ const T* W_ptr = W[idx];
42
+
43
+ // pipeline load W/X and compute WX;
44
+ pipe.producer_acquire();
45
+ cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
46
+ W_ptr + (layer_idx * feat_out + j) * feat_in +
47
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
48
+ cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
49
+ cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
50
+ X + (batch_idx * feat_in) +
51
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
52
+ cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
53
+ pipe.producer_commit();
54
+ size_t copy_idx, compute_idx;
55
+ float y = 0.f;
56
+ flashinfer::vec_t<T, vec_size> x_vec;
57
+ flashinfer::vec_t<T, vec_size> w_vec;
58
+ size_t tile_idx;
59
+
60
+ #pragma unroll
61
+ for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
62
+ ++tile_idx) {
63
+ copy_idx = tile_idx % num_pipeline_stages;
64
+ // pipeline stage: async copy W fragment
65
+ pipe.producer_acquire();
66
+ if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
67
+ cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
68
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
69
+ W_ptr + (layer_idx * feat_out + j) * feat_in +
70
+ tile_idx * tile_size +
71
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
72
+ cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
73
+ cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
74
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
75
+ X + (batch_idx * feat_in) + tile_idx * tile_size +
76
+ (threadIdx.y * tx + threadIdx.x) * vec_size,
77
+ cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
78
+ }
79
+ pipe.producer_commit();
80
+
81
+ compute_idx = (tile_idx - 1) % num_pipeline_stages;
82
+ // pipeline stage: compute WX
83
+ pipe.consumer_wait();
84
+ block.sync();
85
+ x_vec.load(X_shared + X_shared_offset[compute_idx] +
86
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
87
+ w_vec.load(W_shared + W_shared_offset[compute_idx] +
88
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
89
+ float sum = 0.f;
90
+ #pragma unroll
91
+ for (size_t i = 0; i < vec_size; ++i) {
92
+ sum += float(w_vec[i]) * float(x_vec[i]) * scale;
93
+ }
94
+ #pragma unroll
95
+ for (size_t offset = tx / 2; offset > 0; offset /= 2) {
96
+ sum += __shfl_down_sync(0xffffffff, sum, offset);
97
+ }
98
+ y_warpwise[threadIdx.y] = sum;
99
+ block.sync();
100
+ #pragma unroll
101
+ for (size_t i = 0; i < ty; ++i) {
102
+ y += y_warpwise[i];
103
+ }
104
+
105
+ block.sync();
106
+ pipe.consumer_release();
107
+ }
108
+
109
+ compute_idx = (tile_idx - 1) % num_pipeline_stages;
110
+ // final pipeline stage
111
+ pipe.consumer_wait();
112
+ block.sync();
113
+ x_vec.load(X_shared + X_shared_offset[compute_idx] +
114
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
115
+ w_vec.load(W_shared + W_shared_offset[compute_idx] +
116
+ (threadIdx.y * tx + threadIdx.x) * vec_size);
117
+ float sum = 0.f;
118
+ #pragma unroll
119
+ for (size_t i = 0; i < vec_size; ++i) {
120
+ sum += float(w_vec[i]) * float(x_vec[i]) * scale;
121
+ }
122
+ #pragma unroll
123
+ for (size_t offset = tx / 2; offset > 0; offset /= 2) {
124
+ sum += __shfl_down_sync(0xffffffff, sum, offset);
125
+ }
126
+ y_warpwise[threadIdx.y] =
127
+ ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
128
+ ? sum
129
+ : 0.f;
130
+ block.sync();
131
+ #pragma unroll
132
+ for (size_t i = 0; i < ty; ++i) {
133
+ y += y_warpwise[i];
134
+ }
135
+
136
+ block.sync();
137
+ pipe.consumer_release();
138
+
139
+ // write Y;
140
+ if (block.thread_rank() == 0) {
141
+ Y[batch_idx * full_y_size + y_offset + j] += static_cast<T>(y);
142
+ }
143
+ }
144
+
145
+ // nthrs = (2, 16, 4)
146
+ template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
147
+ typename T>
148
+ __global__ void
149
+ bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X,
150
+ T** __restrict__ W,
151
+ const int64_t* __restrict__ indicies, int64_t y_offset,
152
+ int64_t full_y_size, int64_t layer_idx,
153
+ float scale) {
154
+ size_t batch_idx = blockIdx.y;
155
+ int64_t idx = indicies[batch_idx];
156
+
157
+ if (idx < 0) {
158
+ return;
159
+ }
160
+
161
+ auto block = cg::this_thread_block();
162
+ size_t tile_idx = blockIdx.x;
163
+
164
+ const T* W_ptr = W[idx];
165
+
166
+ // load X;
167
+ flashinfer::vec_t<T, vec_size> x_vec;
168
+ x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
169
+
170
+ // load W;
171
+ flashinfer::vec_t<T, vec_size> w_vec;
172
+ w_vec.load(W_ptr + (layer_idx * feat_out + tile_idx * tz * ty) * feat_in +
173
+ block.thread_rank() * vec_size);
174
+
175
+ float sum = 0.f;
176
+ #pragma unroll
177
+ for (size_t i = 0; i < vec_size; ++i) {
178
+ sum += float(w_vec[i]) * float(x_vec[i]) * scale;
179
+ }
180
+
181
+ cg::thread_block_tile g = cg::tiled_partition<tx>(block);
182
+ #pragma unroll
183
+ for (size_t offset = tx / 2; offset > 0; offset /= 2) {
184
+ sum += g.shfl_down(sum, offset);
185
+ }
186
+ sum = g.shfl(sum, 0);
187
+
188
+ if (threadIdx.x == 0) {
189
+ Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
190
+ threadIdx.z * ty + threadIdx.y] += static_cast<T>(sum);
191
+ }
192
+ }
193
+
194
+ template <int feat_in, int feat_out, typename T>
195
+ void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X,
196
+ T** __restrict__ W,
197
+ const int64_t* __restrict__ indicies, int64_t y_offset,
198
+ int64_t full_y_size, int64_t batch_size,
199
+ int64_t layer_idx, float scale) {
200
+ constexpr size_t vec_size = 8;
201
+ constexpr int tz = 4;
202
+ const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
203
+
204
+ if constexpr (feat_in < feat_out) {
205
+ static_assert(feat_in % vec_size == 0);
206
+ constexpr int tx = feat_in / vec_size;
207
+
208
+ static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
209
+ (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
210
+ (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
211
+
212
+ if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
213
+ constexpr int ty = 32 / tx;
214
+ dim3 nblks(feat_out / (ty * tz), batch_size);
215
+ dim3 nthrs(tx, ty, tz);
216
+
217
+ bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
218
+ <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
219
+ full_y_size, layer_idx,
220
+ scale);
221
+ } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
222
+ constexpr int ty = 16 / tx;
223
+ dim3 nblks(feat_out / (ty * tz), batch_size);
224
+ dim3 nthrs(tx, ty, tz);
225
+
226
+ bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
227
+ <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
228
+ full_y_size, layer_idx,
229
+ scale);
230
+ } else {
231
+ constexpr int ty = 8 / tx;
232
+ dim3 nblks(feat_out / (ty * tz), batch_size);
233
+ dim3 nthrs(tx, ty, tz);
234
+
235
+ bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
236
+ <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
237
+ full_y_size, layer_idx,
238
+ scale);
239
+ }
240
+ } else {
241
+ static_assert(feat_in % (vec_size * 32) == 0 ||
242
+ feat_in % (vec_size * 16) == 0 ||
243
+ feat_in % (vec_size * 8) == 0);
244
+
245
+ if constexpr (feat_in % (vec_size * 32) == 0) {
246
+ constexpr int tx = 32;
247
+ constexpr int ty = 4;
248
+
249
+ dim3 nblks(feat_out, batch_size);
250
+ dim3 nthrs(tx, ty);
251
+
252
+ bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(T),
253
+ vec_size * sizeof(T), tx, ty, tz>
254
+ <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
255
+ full_y_size, layer_idx,
256
+ scale);
257
+ } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
258
+ constexpr int tx = 32;
259
+ constexpr int ty = 4;
260
+
261
+ dim3 nblks(feat_out, batch_size);
262
+ dim3 nthrs(tx, ty);
263
+
264
+ bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
265
+ vec_size * sizeof(T) / 2,
266
+ vec_size * sizeof(T) / 2, tx, ty, tz>
267
+ <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
268
+ full_y_size, layer_idx,
269
+ scale);
270
+ } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
271
+ constexpr int tx = 16;
272
+ constexpr int ty = 4;
273
+
274
+ dim3 nblks(feat_out, batch_size);
275
+ dim3 nthrs(tx, ty);
276
+
277
+ bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
278
+ vec_size * sizeof(T) / 2,
279
+ vec_size * sizeof(T) / 2, tx, ty, tz>
280
+ <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
281
+ full_y_size, layer_idx,
282
+ scale);
283
+ }
284
+ }
285
+ }
286
+
287
+ #define INST_BGMV(feat_in, feat_out, T) \
288
+ template void bgmv_kernel<feat_in, feat_out>( \
289
+ T* __restrict__ Y, const T* __restrict__ X, \
290
+ T** __restrict__ W, const int64_t* __restrict__ indicies, \
291
+ int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
292
+ int64_t layer_idx, float scale);
293
+
294
+ #define INST_BGMV_TWOSIDE(T, narrow, wide) \
295
+ INST_BGMV(narrow, wide, T) \
296
+ INST_BGMV(wide, narrow, T)
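The host-side `bgmv_kernel` wrapper above selects `bgmv_expand_kernel` when `feat_in < feat_out` and `bgmv_shrink_kernel` otherwise, launching on the current PyTorch CUDA stream. A hedged usage sketch (not part of the repository), assuming `d_y`, `d_x`, `d_w_ptrs`, and `d_indices` are device buffers prepared by the caller and a PyTorch CUDA context is active:

```cuda
// Hypothetical call site: apply the LoRA "shrink" projection
// (hidden size 4096 -> rank 16) for a batch of 8 tokens at layer 0.
// d_w_ptrs is a device array of per-adapter weight pointers; d_indices
// maps each batch element to an adapter index (negative values skip
// the batch element, matching the `idx < 0` check in the kernels).
bgmv_kernel<4096, 16, nv_half>(
    d_y, d_x, d_w_ptrs, d_indices,
    /*y_offset=*/0, /*full_y_size=*/16, /*batch_size=*/8,
    /*layer_idx=*/0, /*scale=*/1.0f);
```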
build.toml ADDED
@@ -0,0 +1,53 @@
+ [general]
+ name = "punica_sgmv"
+
+ [torch]
+ src = [
+ "torch-ext/torch_binding.cpp",
+ "torch-ext/torch_binding.h"
+ ]
+
+ [kernel.sgmv]
+ language = "cuda"
+ src = [
+ "sgmv/sgmv_cutlass.cu",
+ "sgmv/sgmv_cutlass.cuh",
+ ]
+ depends = [ "cutlass_3_8", "torch" ]
+
+ [kernel.sgmv_flashinfer]
+ language = "cuda"
+ cuda-capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
+ src = [
+ "flashinfer/cp_async.cuh",
+ "flashinfer/mma.cuh",
+ "flashinfer/permuted_smem.cuh",
+ "flashinfer/vec_dtypes.cuh",
+ "sgmv_flashinfer/sgmv_all.cu",
+ "sgmv_flashinfer/sgmv_config.h",
+ "sgmv_flashinfer/sgmv_flashinfer.cuh"
+ ]
+ include = [ "." ]
+ depends = [ "torch" ]
+
+ [kernel.bgmv]
+ language = "cuda"
+ src = [
+ "bgmv/bgmv_all.cu",
+ "bgmv/bgmv_impl.cuh",
+ "bgmv/bgmv_config.h",
+ "flashinfer/vec_dtypes.cuh"
+ ]
+ include = [ "." ]
+ depends = [ "torch" ]
+
+ [kernel.punica_kernels]
+ language = "cuda"
+ src = [
+ "bgmv/bgmv_config.h",
+ "punica_kernels/punica_ops.cc",
+ "sgmv/sgmv.h",
+ "sgmv_flashinfer/sgmv_config.h"
+ ]
+ include = [ "." ]
+ depends = [ "torch" ]
flake.lock ADDED
@@ -0,0 +1,117 @@
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1733328505,
6
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs",
41
+ "rocm-nix": "rocm-nix"
42
+ },
43
+ "locked": {
44
+ "lastModified": 1747143871,
45
+ "narHash": "sha256-gXYPmA7wBqcTy1+39Z/UAIZ5mCSl9W5IoAvDQhIezec=",
46
+ "owner": "huggingface",
47
+ "repo": "kernel-builder",
48
+ "rev": "a78a83cfb31373e0782921999e1917b7f91af7d3",
49
+ "type": "github"
50
+ },
51
+ "original": {
52
+ "owner": "huggingface",
53
+ "repo": "kernel-builder",
54
+ "type": "github"
55
+ }
56
+ },
57
+ "nixpkgs": {
58
+ "locked": {
59
+ "lastModified": 1746711195,
60
+ "narHash": "sha256-bSpM2ySq12PBOVN7jZdzXsc99iRoYOyolh5wz43+CjQ=",
61
+ "owner": "danieldk",
62
+ "repo": "nixpkgs",
63
+ "rev": "6b7a66b06ccb09ac95872ac6ddf952e0660672ab",
64
+ "type": "github"
65
+ },
66
+ "original": {
67
+ "owner": "danieldk",
68
+ "ref": "kernel-builder-cuda-12.9.0",
69
+ "repo": "nixpkgs",
70
+ "type": "github"
71
+ }
72
+ },
73
+ "rocm-nix": {
74
+ "inputs": {
75
+ "nixpkgs": [
76
+ "kernel-builder",
77
+ "nixpkgs"
78
+ ]
79
+ },
80
+ "locked": {
81
+ "lastModified": 1745310663,
82
+ "narHash": "sha256-1U3PzCO/jt7HUlEgLOY3RpxadKwTo6GSvb2j4m0UFw0=",
83
+ "owner": "huggingface",
84
+ "repo": "rocm-nix",
85
+ "rev": "e08373a0efa1c297b0c57af070e0a311df47481f",
86
+ "type": "github"
87
+ },
88
+ "original": {
89
+ "owner": "huggingface",
90
+ "repo": "rocm-nix",
91
+ "type": "github"
92
+ }
93
+ },
94
+ "root": {
95
+ "inputs": {
96
+ "kernel-builder": "kernel-builder"
97
+ }
98
+ },
99
+ "systems": {
100
+ "locked": {
101
+ "lastModified": 1681028828,
102
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
103
+ "owner": "nix-systems",
104
+ "repo": "default",
105
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "nix-systems",
110
+ "repo": "default",
111
+ "type": "github"
112
+ }
113
+ }
114
+ },
115
+ "root": "root",
116
+ "version": 7
117
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
+ {
+ description = "Flake for Punica SGMV kernel";
+
+ inputs = {
+ kernel-builder.url = "github:huggingface/kernel-builder";
+ };
+
+ outputs =
+ {
+ self,
+ kernel-builder,
+ }:
+ kernel-builder.lib.genFlakeOutputs {
+ path = ./.;
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+ };
+ }
flashinfer/cp_async.cuh ADDED
@@ -0,0 +1,187 @@
1
+ /*
2
+ * Copyright (c) 2023 by FlashInfer team.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #ifndef FLASHINFER_CP_ASYNC_CUH_
17
+ #define FLASHINFER_CP_ASYNC_CUH_
18
+
19
+ #include <cuda_runtime.h>
20
+
21
+ namespace flashinfer {
22
+
23
+ namespace cp_async {
24
+
25
+ enum class SharedMemFillMode {
26
+ kFillZero, // Fill zero to shared memory when predicate is false
27
+ kNoFill // Do not fill zero to shared memory when predicate is false
28
+ };
29
+
30
+ enum class PrefetchMode {
31
+ kNoPrefetch, // Do not fetch additional data from global memory to L2
32
+ kPrefetch // Fetch additional data from global memory to L2
33
+ };
34
+
35
+ #if (__CUDACC_VER_MAJOR__ >= 11)
36
+ #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
37
+ #define FLASHINFER_CP_ASYNC_ENABLED
38
+ #endif
39
+ #endif
40
+
41
+ /*!
42
+ * \brief Wrapper of PTX cp.async.commit_group instruction, commit all prior uncommitted
43
+ * cp.async instructions to a group
44
+ */
45
+ __device__ __forceinline__ void commit_group() {
46
+ #ifdef FLASHINFER_CP_ASYNC_ENABLED
47
+ asm volatile("cp.async.commit_group;\n" ::);
48
+ #endif
49
+ }
50
+
51
+ /*!
52
+ * \brief Wrapper of PTX cp.async.wait_group instruction
53
+ * \tparam n Wait till most recent n groups are committed
54
+ */
55
+ template <size_t n>
56
+ __device__ __forceinline__ void wait_group() {
57
+ #ifdef FLASHINFER_CP_ASYNC_ENABLED
58
+ asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
59
+ #endif
60
+ }
61
+
62
+ /*!
63
+ * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from
64
+ * global memory to shared memory
65
+ * \tparam prefetch_mode Whether to fetch additional data from global memory to L2
66
+ * \tparam T Data type
67
+ * \param smem_ptr Pointer to shared memory
68
+ * \param gmem_ptr Pointer to global memory
69
+ */
70
+ template <PrefetchMode prefetch_mode, typename T>
71
+ __device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
72
+ #ifdef FLASHINFER_CP_ASYNC_ENABLED
73
+ uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
74
+ if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
75
+ asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
76
+ "l"(gmem_ptr), "n"(16), "r"(16));
77
+ } else {
78
+ asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
79
+ "l"(gmem_ptr), "n"(16), "r"(16));
80
+ }
81
+ #else
82
+ *((uint4*)smem_ptr) = *((uint4*)gmem_ptr);
83
+ #endif
84
+ }
85
+
86
+ /*!
87
+ * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from
88
+ * global memory to shared memory with predicate.
89
+ * \tparam prefetch_mode Whether to fetch additional data from global memory to L2
90
+ * \tparam fill_mode Whether to fill zero to shared memory when predicate is false
91
+ * \tparam T Data type
92
+ * \param smem_ptr Pointer to shared memory
93
+ * \param gmem_ptr Pointer to global memory
94
+ * \param predicate Predicate value
95
+ * \note fill zero is slower than not fill zero
96
+ */
97
+ template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
98
+ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) {
99
+ #ifdef FLASHINFER_CP_ASYNC_ENABLED
100
+ uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
101
+ if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
102
+ int src_in_bytes = predicate ? 16 : 0;
103
+ if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
104
+ asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
105
+ "l"(gmem_ptr), "n"(16), "r"(src_in_bytes));
106
+ } else {
107
+ asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
108
+ "l"(gmem_ptr), "n"(16), "r"(src_in_bytes));
109
+ }
110
+ } else {
111
+ if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
112
+ asm volatile(
113
+ "{\n"
114
+ " .reg .pred p;\n"
115
+ " setp.ne.b32 p, %0, 0;\n"
116
+ " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
117
+ "}\n" ::"r"((int)predicate),
118
+ "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16));
119
+ } else {
120
+ asm volatile(
121
+ "{\n"
122
+ " .reg .pred p;\n"
123
+ " setp.ne.b32 p, %0, 0;\n"
124
+ " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
125
+ "}\n" ::"r"((int)predicate),
126
+ "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16));
127
+ }
128
+ }
129
+ #else
130
+ if (predicate) {
131
+ *((uint4*)smem_ptr) = *((uint4*)gmem_ptr);
132
+ } else {
133
+ if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
134
+ *((uint4*)smem_ptr) = make_uint4(0, 0, 0, 0);
135
+ }
136
+ }
137
+ #endif
138
+ }
139
+
140
+ /*!
141
+ * \brief Load specified number of bits per thread from global memory to shared memory
142
+ * \tparam num_bits Number of bits to load, must be 128 or 256
143
+ * \tparam prefetch_mode Whether to fetch additional data from global memory to L2
144
+ * \tparam T Data type
145
+ * \param smem_ptr Pointer to shared memory
146
+ * \param gmem_ptr Pointer to global memory
147
+ */
148
+ template <size_t num_bits, PrefetchMode prefetch_mode, typename T>
149
+ __device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) {
150
+ static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256");
151
+ if constexpr (num_bits == 128) {
152
+ load_128b<prefetch_mode>(smem_ptr, gmem_ptr);
153
+ } else {
154
+ load_128b<prefetch_mode>(smem_ptr, gmem_ptr);
155
+ load_128b<prefetch_mode>(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T));
156
+ }
157
+ }
158
+
159
+ /*!
160
+ * \brief Load specified number of bits per thread from global memory to shared memory with
161
+ * predicate
162
+ * \tparam num_bits Number of bits to load, must be 128 or 256
163
+ * \tparam prefetch_mode Whether to fetch additional data from global memory to L2
164
+ * \tparam fill_mode Whether to fill zero to shared memory when predicate is false
165
+ * \tparam T Data type
166
+ * \param smem_ptr Pointer to shared memory
167
+ * \param gmem_ptr Pointer to global memory
168
+ * \param predicate Predicate value
169
+ * \note fill zero is slower than not fill zero
170
+ */
171
+ template <size_t num_bits, PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
172
+ __device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) {
173
+ static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256");
174
+ if constexpr (num_bits == 128) {
175
+ pred_load_128b<prefetch_mode, fill_mode>(smem_ptr, gmem_ptr, predicate);
176
+ } else {
177
+ pred_load_128b<prefetch_mode, fill_mode>(smem_ptr, gmem_ptr, predicate);
178
+ pred_load_128b<prefetch_mode, fill_mode>(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T),
179
+ predicate);
180
+ }
181
+ }
182
+
183
+ } // namespace cp_async
184
+
185
+ } // namespace flashinfer
186
+
187
+ #endif // FLASHINFER_CP_ASYNC_CUH_
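A minimal sketch of how these wrappers are typically combined (the surrounding kernel, buffer sizes, and the `stage_tile` name are illustrative assumptions, not from the repository): each thread asynchronously stages one 128-bit chunk into shared memory with zero-fill for out-of-range lanes, commits the group, and waits before computing.

```cuda
#include <cuda_runtime.h>
#include "flashinfer/cp_async.cuh"

using namespace flashinfer::cp_async;

__global__ void stage_tile(const float* __restrict__ gmem, int n) {
  __shared__ float smem[512];
  int i = threadIdx.x * 4;  // 4 floats = 128 bits per thread
  // Predicated async copy: lanes past n are zero-filled in shared memory.
  pred_load<128, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
      smem + i, gmem + i, i < n);
  commit_group();
  wait_group<0>();  // wait until all committed groups have landed
  __syncthreads();
  // ... compute on smem ...
}
```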
flashinfer/mma.cuh ADDED
@@ -0,0 +1,410 @@
1
+ /*
2
+ * Copyright (c) 2023 by FlashInfer team.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #ifndef FLASHINFER_MMA_CUH_
17
+ #define FLASHINFER_MMA_CUH_
18
+
19
+ #include <cuda_bf16.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_runtime.h>
22
+
23
+ #include <type_traits>
24
+
25
+ namespace flashinfer {
26
+
27
+ namespace mma {
28
+
29
+ #if (__CUDACC_VER_MAJOR__ >= 11)
30
+ #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
31
+ #define FLASHINFER_STMATRIX_M8N8X4_ENABLED
32
+ #endif
33
+ #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
34
+ #define FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED
35
+ #define FLASHINFER_MMA_F16F16F16_M16N8K16_ENABLED
36
+ #endif
37
+ #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750))
38
+ #define FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED
39
+ #define FLASHINFER_MMA_F16F16F16_M16N8K8_ENABLED
40
+ #define FLASHINFER_LDMATRIX_M8N8X4_ENABLED
41
+ #endif
42
+ #endif
43
+
44
+ enum class MMAMode {
45
+ kInit = 0U,
46
+ kInplaceUpdate = 1U,
47
+ };
48
+
49
+ /*!
50
+ * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared memory
51
+ * to fragment
52
+ * \tparam T data type of the fragment
53
+ * \param R pointer to the fragment
54
+ * \param smem_ptr pointer to the shared memory
55
+ */
56
+ template <typename T>
57
+ __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
58
+ #ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED
59
+ uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
60
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
61
+ : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
62
+ : "r"(smem_int_ptr));
63
+ #else
64
+ #error "Unsupported CUDA architecture for ldmatrix instruction"
65
+ #endif
66
+ }
67
+
68
+ /*!
69
+ * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from
70
+ * shared memory to fragment and transposes the fragment
71
+ * \tparam T data type of the fragment
72
+ * \param R pointer to the fragment
73
+ * \param smem_ptr pointer to the shared memory
74
+ */
75
+ template <typename T>
76
+ __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, T* smem_ptr) {
77
+ #ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED
78
+ uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
79
+ asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
80
+ : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
81
+ : "r"(smem_int_ptr));
82
+ #else
83
+ #error "Unsupported CUDA architecture for ldmatrix instruction"
84
+ #endif
85
+ }
86
+
87
+ /*!
88
+ * \brief Wrapper of PTX stmatrix m8n8.x4 instruction, stores data from fragment
89
+ * to shared memory
90
+ * \tparam T data type of the fragment
91
+ * \param R pointer to the fragment
92
+ * \param smem_ptr pointer to the shared memory
93
+ */
94
+ template <typename T>
95
+ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
96
+ #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
97
+ uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
98
+ asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n"
99
+ :
100
+ : "r"(smem_int_ptr), "r"(R[0]), "r"(R[1]), "r"(R[2]), "r"(R[3]));
101
+ #else
102
+ // Fallback implementation, slower than PTX instruction
103
+ const uint32_t tx = threadIdx.x;
104
+ uint4 word;
105
+ #pragma unroll
106
+ for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) {
107
+ word.x = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4);
108
+ word.y = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 1);
109
+ word.z = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 2);
110
+ word.w = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 3);
111
+ if (tx / 8 == reg_id) {
112
+ *(uint4*)smem_ptr = word;
113
+ }
114
+ }
115
+ #endif
116
+ }
117
+
118
+ /*!
119
+ * \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix
120
+ * multiplication, accumulated in f32.
121
+ * \tparam T data type of the fragment
122
+ * \tparam mma_mode whether we are initializing the accumulator or updating it
123
+ * \param C pointer to the accumulator
124
+ * \param A pointer to the fragment of matrix A
125
+ * \param B pointer to the fragment of matrix B
126
+ */
127
+ template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
128
+ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A,
129
+ uint32_t* B) {
130
+ #if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED)
131
+ if constexpr (mma_mode == MMAMode::kInit) {
132
+ if constexpr (std::is_same<T, half>::value) {
133
+ asm volatile(
134
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
135
+ "{%0, %1, %2, %3},"
136
+ "{%4, %5, %6, %7},"
137
+ "{%8, %9},"
138
+ "{%10, %11, %12, %13};\n"
139
+ : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
140
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f),
141
+ "f"(0.f), "f"(0.f));
142
+ asm volatile(
143
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
144
+ "{%0, %1, %2, %3},"
145
+ "{%4, %5, %6, %7},"
146
+ "{%8, %9},"
147
+ "{%10, %11, %12, %13};\n"
148
+ : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
149
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f),
150
+ "f"(0.f), "f"(0.f));
151
+ } else {
152
+ asm volatile(
153
+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
154
+ "{%0, %1, %2, %3},"
155
+ "{%4, %5, %6, %7},"
156
+ "{%8, %9},"
157
+ "{%10, %11, %12, %13};\n"
158
+ : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
159
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f),
160
+ "f"(0.f), "f"(0.f));
161
+ asm volatile(
162
+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
163
+ "{%0, %1, %2, %3},"
164
+ "{%4, %5, %6, %7},"
165
+ "{%8, %9},"
166
+ "{%10, %11, %12, %13};\n"
167
+ : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
168
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f),
169
+ "f"(0.f), "f"(0.f));
170
+ }
171
+ } else {
172
+ if constexpr (std::is_same<T, half>::value) {
173
+ asm volatile(
174
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
175
+ "{%0, %1, %2, %3},"
176
+ "{%4, %5, %6, %7},"
177
+ "{%8, %9},"
178
+ "{%10, %11, %12, %13};\n"
179
+ : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
180
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]),
181
+ "f"(C[2]), "f"(C[3]));
182
+ asm volatile(
183
+ "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
184
+ "{%0, %1, %2, %3},"
185
+ "{%4, %5, %6, %7},"
186
+ "{%8, %9},"
187
+ "{%10, %11, %12, %13};\n"
188
+ : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
189
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]),
190
+ "f"(C[6]), "f"(C[7]));
191
+ } else {
192
+ asm volatile(
193
+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
194
+ "{%0, %1, %2, %3},"
195
+ "{%4, %5, %6, %7},"
196
+ "{%8, %9},"
197
+ "{%10, %11, %12, %13};\n"
198
+ : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
199
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]),
200
+ "f"(C[2]), "f"(C[3]));
201
+ asm volatile(
202
+ "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
203
+ "{%0, %1, %2, %3},"
204
+ "{%4, %5, %6, %7},"
205
+ "{%8, %9},"
206
+ "{%10, %11, %12, %13};\n"
207
+ : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
208
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]),
209
+ "f"(C[6]), "f"(C[7]));
210
+ }
211
+ }
212
+ #elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED)
213
+ static_assert(std::is_same<T, half>::value, "bf16 mma instruction is not supported on sm_75");
214
+ if constexpr (mma_mode == MMAMode::kInit) {
215
+ asm volatile(
216
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
217
+ "{%0, %1, %2, %3},"
218
+ "{%4, %5},"
219
+ "{%6},"
220
+ "{%7, %8, %9, %10};\n"
221
+ : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
222
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
223
+ asm volatile(
224
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
225
+ "{%0, %1, %2, %3},"
226
+ "{%4, %5},"
227
+ "{%6},"
228
+ "{%7, %8, %9, %10};\n"
229
+ : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
230
+ : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
231
+ asm volatile(
232
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
233
+ "{%0, %1, %2, %3},"
234
+ "{%4, %5},"
235
+ "{%6},"
236
+ "{%7, %8, %9, %10};\n"
237
+ : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
238
+ : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
239
+ asm volatile(
240
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
241
+ "{%0, %1, %2, %3},"
242
+ "{%4, %5},"
243
+ "{%6},"
244
+ "{%7, %8, %9, %10};\n"
245
+ : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
246
+ : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
247
+ } else {
248
+ asm volatile(
249
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
250
+ "{%0, %1, %2, %3},"
251
+ "{%4, %5},"
252
+ "{%6},"
253
+ "{%7, %8, %9, %10};\n"
254
+ : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
255
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
256
+ asm volatile(
257
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
258
+ "{%0, %1, %2, %3},"
259
+ "{%4, %5},"
260
+ "{%6},"
261
+ "{%7, %8, %9, %10};\n"
262
+ : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
263
+ : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
264
+ asm volatile(
265
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
266
+ "{%0, %1, %2, %3},"
267
+ "{%4, %5},"
268
+ "{%6},"
269
+ "{%7, %8, %9, %10};\n"
270
+ : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
271
+ : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
272
+ asm volatile(
273
+ "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
274
+ "{%0, %1, %2, %3},"
275
+ "{%4, %5},"
276
+ "{%6},"
277
+ "{%7, %8, %9, %10};\n"
278
+ : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
279
+ : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
280
+ }
281
+ #else
282
+ #error "Unsupported CUDA architecture for mma instruction"
283
+ #endif
284
+ }
285
+
286
+ /*!
287
+ * \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix
288
+ * multiplication, accumulated in f16.
289
+ * \tparam mma_mode whether we are initializing the accumulator or updating it
290
+ * \param C pointer to the accumulator
291
+ * \param A pointer to the fragment of matrix A
292
+ * \param B pointer to the fragment of matrix B
293
+ */
294
+ template <MMAMode mma_mode = MMAMode::kInplaceUpdate>
295
+ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16(uint32_t* C, uint32_t* A,
296
+ uint32_t* B) {
297
+ #if defined(FLASHINFER_MMA_F16F16F16_M16N8K16_ENABLED)
298
+ if constexpr (mma_mode == MMAMode::kInit) {
299
+ asm volatile(
300
+ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
301
+ "{%0, %1},"
302
+ "{%2, %3, %4, %5},"
303
+ "{%6, %7},"
304
+ "{%8, %9};\n"
305
+ : "=r"(C[0]), "=r"(C[1])
306
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0));
307
+ asm volatile(
308
+ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
309
+ "{%0, %1},"
310
+ "{%2, %3, %4, %5},"
311
+ "{%6, %7},"
312
+ "{%8, %9};\n"
313
+ : "=r"(C[2]), "=r"(C[3])
314
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0));
315
+ } else {
316
+ asm volatile(
317
+ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
318
+ "{%0, %1},"
319
+ "{%2, %3, %4, %5},"
320
+ "{%6, %7},"
321
+ "{%8, %9};\n"
322
+ : "=r"(C[0]), "=r"(C[1])
323
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]));
324
+ asm volatile(
325
+ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
326
+ "{%0, %1},"
327
+ "{%2, %3, %4, %5},"
328
+ "{%6, %7},"
329
+ "{%8, %9};\n"
330
+ : "=r"(C[2]), "=r"(C[3])
331
+ : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[2]), "r"(C[3]));
332
+ }
333
+ #elif defined(FLASHINFER_MMA_F16F16F16_M16N8K8_ENABLED)
334
+ if constexpr (mma_mode == MMAMode::kInit) {
335
+ asm volatile(
336
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
337
+ "{%0, %1},"
338
+ "{%2, %3},"
339
+ "{%4},"
340
+ "{%5, %6};\n"
341
+ : "=r"(C[0]), "=r"(C[1])
342
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(0), "r"(0));
343
+ asm volatile(
344
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
345
+ "{%0, %1},"
346
+ "{%2, %3},"
347
+ "{%4},"
348
+ "{%5, %6};\n"
349
+ : "=r"(C[0]), "=r"(C[1])
350
+ : "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(0), "r"(0));
351
+ asm volatile(
352
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
353
+ "{%0, %1},"
354
+ "{%2, %3},"
355
+ "{%4},"
356
+ "{%5, %6};\n"
357
+ : "=r"(C[2]), "=r"(C[3])
358
+ : "r"(A[0]), "r"(A[1]), "r"(B[2]), "r"(0), "r"(0));
359
+ asm volatile(
360
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
361
+ "{%0, %1},"
362
+ "{%2, %3},"
363
+ "{%4},"
364
+ "{%5, %6};\n"
365
+ : "=r"(C[2]), "=r"(C[3])
366
+ : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(0), "r"(0));
367
+ } else {
368
+ asm volatile(
369
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
370
+ "{%0, %1},"
371
+ "{%2, %3},"
372
+ "{%4},"
373
+ "{%5, %6};\n"
374
+ : "=r"(C[0]), "=r"(C[1])
375
+ : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));
376
+ asm volatile(
377
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
378
+ "{%0, %1},"
379
+ "{%2, %3},"
380
+ "{%4},"
381
+ "{%5, %6};\n"
382
+ : "=r"(C[0]), "=r"(C[1])
383
+ : "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(C[0]), "r"(C[1]));
384
+ asm volatile(
385
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
386
+ "{%0, %1},"
387
+ "{%2, %3},"
388
+ "{%4},"
389
+ "{%5, %6};\n"
390
+ : "=r"(C[2]), "=r"(C[3])
391
+ : "r"(A[0]), "r"(A[1]), "r"(B[2]), "r"(C[2]), "r"(C[3]));
392
+ asm volatile(
393
+ "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
394
+ "{%0, %1},"
395
+ "{%2, %3},"
396
+ "{%4},"
397
+ "{%5, %6};\n"
398
+ : "=r"(C[2]), "=r"(C[3])
399
+ : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(C[2]), "r"(C[3]));
400
+ }
401
+ #else
402
+ #error "Unsupported CUDA architecture for mma instruction"
403
+ #endif
404
+ }
405
+
406
+ } // namespace mma
407
+
408
+ } // namespace flashinfer
409
+
410
+ #endif // FLASHINFER_MMA_CUH_
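A minimal sketch of the intended call pattern, under stated assumptions: the helper name is hypothetical, the shared-memory layout and lane-to-fragment mapping follow the PTX m16n8k16 conventions and are elided here, and whether the transposed `ldmatrix` variant is needed for B depends on how the tile was written.

```cuda
#include <cuda_fp16.h>
#include "flashinfer/mma.cuh"

using namespace flashinfer;

// a_smem / b_smem point at 16x16 half tiles already staged in shared memory;
// C holds the 8 f32 accumulator registers of the 16x16 output tile.
__device__ void mma_16x16x16_f32(half* a_smem, half* b_smem, float* C) {
  uint32_t A[4], B[4];
  mma::ldmatrix_m8n8x4(A, a_smem);
  mma::ldmatrix_m8n8x4(B, b_smem);
  mma::mma_sync_m16n16k16_row_col_f16f16f32<half, mma::MMAMode::kInit>(C, A, B);
}
```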
flashinfer/permuted_smem.cuh ADDED
@@ -0,0 +1,95 @@
1
+ /*
2
+ * Copyright (c) 2023 by FlashInfer team.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #ifndef FLASHINFER_PERMUTED_SMEM_CUH_
17
+ #define FLASHINFER_PERMUTED_SMEM_CUH_
18
+
19
+ #include <cuda_bf16.h>
20
+ #include <cuda_fp16.h>
21
+ #include <cuda_runtime.h>
22
+
23
+ #include <cuda/pipeline>
24
+
25
+ #include "cp_async.cuh"
26
+ #include "mma.cuh"
27
+
28
+ namespace flashinfer {
29
+
30
+ // Each cell is 4 bytes.
31
+ using cell_t = uint4;
32
+
33
+ /*!
34
+ * \brief Compute the number of elements that can be stored in a cell.
35
+ * \tparam T The data type of the elements.
36
+ */
37
+ template <typename T>
38
+ constexpr __host__ __device__ __forceinline__ uint32_t cell_capacity() {
39
+ return sizeof(cell_t) / sizeof(T);
40
+ }
41
+
42
+ /*!
43
+ * \brief The shared memory wrapper.
44
+ */
45
+ struct smem_t {
46
+ // The base pointer.
47
+ cell_t* base;
48
+ __device__ __forceinline__ smem_t() : base(nullptr) {}
49
+ template <typename T>
50
+ __device__ __forceinline__ smem_t(T* base) : base((cell_t*)base) {}
51
+
52
+ /*!
53
+ * \brief Compute the element offset given coordinates in a permuted shared memory.
54
+ * \tparam stride The stride (in terms of cells) in the permuted shared memory.
55
+ * \param i The row index.
56
+ * \param j The column index.
57
+ */
58
+ template <uint32_t stride>
59
+ static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) {
60
+ return (i / 2) * stride * 2 + (j / 4) * 8 + (i % 2) * 4 + ((j % 4) ^ ((i / 2) % 4));
61
+ }
62
+
63
+ __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, uint32_t* R) {
64
+ cell_t* smem_ptr = base + offset;
65
+ mma::ldmatrix_m8n8x4(R, smem_ptr);
66
+ }
67
+ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t offset, uint32_t* R) {
68
+ cell_t* smem_ptr = base + offset;
69
+ mma::stmatrix_m8n8x4(R, smem_ptr);
70
+ }
71
+ __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset, uint32_t* R) {
72
+ cell_t* smem_ptr = base + offset;
73
+ mma::ldmatrix_m8n8x4_trans(R, smem_ptr);
74
+ }
75
+ template <cp_async::SharedMemFillMode fill_mode, typename T>
76
+ __device__ __forceinline__ void load_128b_async(uint32_t offset, const T* gptr, bool predicate) {
77
+ cell_t* smem_ptr = base + offset;
78
+ cp_async::pred_load_128b<cp_async::PrefetchMode::kPrefetch, fill_mode>(
79
+ smem_ptr, reinterpret_cast<const cell_t*>(gptr), predicate);
80
+ }
81
+ template <typename T>
82
+ __device__ __forceinline__ void load_128b_async(uint32_t offset, const T* gptr) {
83
+ cell_t* smem_ptr = base + offset;
84
+ cp_async::load_128b<cp_async::PrefetchMode::kPrefetch>(smem_ptr,
85
+ reinterpret_cast<const cell_t*>(gptr));
86
+ }
87
+ template <typename T>
88
+ __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) {
89
+ *reinterpret_cast<cell_t*>(gptr) = *(base + offset);
90
+ }
91
+ };
92
+
93
+ } // namespace flashinfer
94
+
95
+ #endif // FLASHINFER_PERMUTED_SMEM_CUH_
flashinfer/vec_dtypes.cuh ADDED
@@ -0,0 +1,1262 @@
1
+ /*
2
+ * Copyright (c) 2023 by FlashInfer team.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+ #ifndef VEC_DTYPES_CUH_
17
+ #define VEC_DTYPES_CUH_
18
+
19
+ #include <cuda_bf16.h>
20
+ #include <cuda_fp16.h>
21
+ #ifdef FLASHINFER_ENABLE_FP8
22
+ #include <cuda_fp8.h>
23
+ #endif
24
+ #include <cuda_runtime.h>
25
+
26
+ #include <type_traits>
27
+
28
+ namespace flashinfer {
29
+
30
+ #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ __host__
31
+
32
+ template <typename float_t, size_t vec_size>
33
+ struct vec_t {
34
+ FLASHINFER_INLINE float_t& operator[](size_t i);
35
+ FLASHINFER_INLINE const float_t& operator[](size_t i) const;
36
+ FLASHINFER_INLINE void fill(float_t val);
37
+ FLASHINFER_INLINE void load(const float_t* ptr);
38
+ FLASHINFER_INLINE void store(float_t* ptr) const;
39
+ template <typename T>
40
+ FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src);
41
+ template <typename T>
42
+ FLASHINFER_INLINE void cast_load(const T* ptr);
43
+ template <typename T>
44
+ FLASHINFER_INLINE void cast_store(T* ptr) const;
45
+ FLASHINFER_INLINE static void memcpy(float_t* dst, const float_t* src);
46
+ FLASHINFER_INLINE float_t* ptr();
47
+ };
48
+
49
+ template <typename src_float_t, typename tgt_float_t, size_t vec_size>
50
+ FLASHINFER_INLINE void cast_from_impl(vec_t<tgt_float_t, vec_size>& dst,
51
+ const vec_t<src_float_t, vec_size>& src) {
52
+ #pragma unroll
53
+ for (size_t i = 0; i < vec_size; ++i) {
54
+ dst[i] = tgt_float_t(src[i]);
55
+ }
56
+ }
57
+
58
+ template <typename src_float_t, typename tgt_float_t, size_t vec_size>
59
+ FLASHINFER_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
60
+ const src_float_t* src_ptr) {
61
+ if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
62
+ dst.load(src_ptr);
63
+ } else {
64
+ vec_t<src_float_t, vec_size> tmp;
65
+ tmp.load(src_ptr);
66
+ dst.cast_from(tmp);
67
+ }
68
+ }
69
+
70
+ template <typename src_float_t, typename tgt_float_t, size_t vec_size>
71
+ FLASHINFER_INLINE void cast_store_impl(tgt_float_t* dst_ptr,
72
+ const vec_t<src_float_t, vec_size>& src) {
73
+ if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
74
+ src.store(dst_ptr);
75
+ } else {
76
+ vec_t<tgt_float_t, vec_size> tmp;
77
+ tmp.cast_from(src);
78
+ tmp.store(dst_ptr);
79
+ }
80
+ }
81
+
82
+ #ifdef FLASHINFER_ENABLE_FP8
83
+ /******************* vec_t<__nv_fp8_e4m3> *******************/
84
+
85
+ // __nv_fp8_e4m3 x 1
86
+ template <>
87
+ struct vec_t<__nv_fp8_e4m3, 1> {
88
+ __nv_fp8_e4m3 data;
89
+
90
+ FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)(&data))[i]; }
91
+ FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
92
+ return ((const __nv_fp8_e4m3*)(&data))[i];
93
+ }
94
+ FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
95
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
96
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr);
97
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const;
98
+ template <typename T>
99
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
100
+ cast_from_impl(*this, src);
101
+ }
102
+ template <typename T>
103
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
104
+ cast_load_impl(*this, ptr);
105
+ }
106
+ template <typename T>
107
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
108
+ cast_store_impl(ptr, *this);
109
+ }
110
+
111
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src);
112
+ };
113
+
114
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { data = val; }
115
+
116
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) { data = *ptr; }
117
+
118
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(__nv_fp8_e4m3* ptr) const { *ptr = data; }
119
+
120
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(__nv_fp8_e4m3* dst,
121
+ const __nv_fp8_e4m3* src) {
122
+ *dst = *src;
123
+ }
124
+
125
+ // __nv_fp8_e4m3 x 2
126
+ template <>
127
+ struct vec_t<__nv_fp8_e4m3, 2> {
128
+ __nv_fp8x2_e4m3 data;
129
+
130
+ FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)(&data))[i]; }
131
+ FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
132
+ return ((const __nv_fp8_e4m3*)(&data))[i];
133
+ }
134
+ FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
135
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
136
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr);
137
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const;
138
+ template <typename T>
139
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
140
+ cast_from_impl(*this, src);
141
+ }
142
+ template <typename T>
143
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
144
+ cast_load_impl(*this, ptr);
145
+ }
146
+ template <typename T>
147
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
148
+ cast_store_impl(ptr, *this);
149
+ }
150
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src);
151
+ };
152
+
153
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
154
+ data.__x = (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
155
+ }
156
+
157
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) {
158
+ data = *((__nv_fp8x2_e4m3*)ptr);
159
+ }
160
+
161
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(__nv_fp8_e4m3* ptr) const {
162
+ *((__nv_fp8x2_e4m3*)ptr) = data;
163
+ }
164
+
165
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(__nv_fp8_e4m3* dst,
166
+ const __nv_fp8_e4m3* src) {
167
+ *((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src);
168
+ }
169
+
170
+ // __nv_fp8_e4m3 x 4
171
+
172
+ template <>
173
+ struct vec_t<__nv_fp8_e4m3, 4> {
174
+ __nv_fp8x4_e4m3 data;
175
+
176
+ FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)(&data))[i]; }
177
+ FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
178
+ return ((const __nv_fp8_e4m3*)(&data))[i];
179
+ }
180
+ FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
181
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
182
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr);
183
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const;
184
+ template <typename T>
185
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
186
+ cast_from_impl(*this, src);
187
+ }
188
+ template <typename T>
189
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
190
+ cast_load_impl(*this, ptr);
191
+ }
192
+ template <typename T>
193
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
194
+ cast_store_impl(ptr, *this);
195
+ }
196
+
197
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src);
198
+ };
199
+
200
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
201
+ data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
202
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
203
+ }
204
+
205
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) {
206
+ data = *((__nv_fp8x4_e4m3*)ptr);
207
+ }
208
+
209
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(__nv_fp8_e4m3* ptr) const {
210
+ *((__nv_fp8x4_e4m3*)ptr) = data;
211
+ }
212
+
213
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(__nv_fp8_e4m3* dst,
214
+ const __nv_fp8_e4m3* src) {
215
+ *((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src);
216
+ }
217
+
218
+ // __nv_fp8_e4m3 x 8
219
+
220
+ template <>
221
+ struct vec_t<__nv_fp8_e4m3, 8> {
222
+ uint2 data;
223
+
224
+ FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)(&data))[i]; }
225
+ FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
226
+ return ((const __nv_fp8_e4m3*)(&data))[i];
227
+ }
228
+ FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
229
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
230
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr);
231
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const;
232
+ template <typename T>
233
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 8>& src) {
234
+ cast_from_impl(*this, src);
235
+ }
236
+ template <typename T>
237
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
238
+ cast_load_impl(*this, ptr);
239
+ }
240
+ template <typename T>
241
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
242
+ cast_store_impl(ptr, *this);
243
+ }
244
+
245
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src);
246
+ };
247
+
248
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
249
+ ((__nv_fp8x4_e4m3*)(&data.x))->__x =
250
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
251
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
252
+ ((__nv_fp8x4_e4m3*)(&data.y))->__x =
253
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
254
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
255
+ }
256
+
257
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) {
258
+ data = *((uint2*)ptr);
259
+ }
260
+
261
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(__nv_fp8_e4m3* ptr) const {
262
+ *((uint2*)ptr) = data;
263
+ }
264
+
265
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(__nv_fp8_e4m3* dst,
266
+ const __nv_fp8_e4m3* src) {
267
+ *((uint2*)dst) = *((uint2*)src);
268
+ }
269
+
270
+ // __nv_fp8_e4m3 x 16 or more
271
+ template <size_t vec_size>
272
+ struct vec_t<__nv_fp8_e4m3, vec_size> {
273
+ uint4 data[vec_size / 16];
274
+
275
+ FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)data)[i]; }
276
+ FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
277
+ return ((const __nv_fp8_e4m3*)data)[i];
278
+ }
279
+ FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
280
+ FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) {
281
+ #pragma unroll
282
+ for (size_t i = 0; i < vec_size / 16; ++i) {
283
+ ((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x =
284
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
285
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
286
+ ((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x =
287
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
288
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
289
+ ((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x =
290
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
291
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
292
+ ((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x =
293
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
294
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
295
+ }
296
+ }
297
+ FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr) {
298
+ #pragma unroll
299
+ for (size_t i = 0; i < vec_size / 16; ++i) {
300
+ data[i] = ((uint4*)ptr)[i];
301
+ }
302
+ }
303
+ FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const {
304
+ #pragma unroll
305
+ for (size_t i = 0; i < vec_size / 16; ++i) {
306
+ ((uint4*)ptr)[i] = data[i];
307
+ }
308
+ }
309
+ template <typename T>
310
+ FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
311
+ cast_from_impl(*this, src);
312
+ }
313
+ template <typename T>
314
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
315
+ cast_load_impl(*this, ptr);
316
+ }
317
+ template <typename T>
318
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
319
+ cast_store_impl(ptr, *this);
320
+ }
321
+
322
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
323
+ #pragma unroll
324
+ for (size_t i = 0; i < vec_size / 16; ++i) {
325
+ ((uint4*)dst)[i] = ((uint4*)src)[i];
326
+ }
327
+ }
328
+ };
329
+
330
+ /******************* vec_t<__nv_fp8_e5m2> *******************/
331
+
332
+ // __nv_fp8_e5m2 x 1
333
+ template <>
334
+ struct vec_t<__nv_fp8_e5m2, 1> {
335
+ __nv_fp8_e5m2 data;
336
+
337
+ FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)(&data))[i]; }
338
+ FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
339
+ return ((const __nv_fp8_e5m2*)(&data))[i];
340
+ }
341
+ FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
342
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
343
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr);
344
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const;
345
+ template <typename T>
346
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
347
+ cast_from_impl(*this, src);
348
+ }
349
+ template <typename T>
350
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
351
+ cast_load_impl(*this, ptr);
352
+ }
353
+ template <typename T>
354
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
355
+ cast_store_impl(ptr, *this);
356
+ }
357
+
358
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src);
359
+ };
360
+
361
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { data = val; }
362
+
363
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) { data = *ptr; }
364
+
365
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(__nv_fp8_e5m2* ptr) const { *ptr = data; }
366
+
367
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(__nv_fp8_e5m2* dst,
368
+ const __nv_fp8_e5m2* src) {
369
+ *dst = *src;
370
+ }
371
+
372
+ // __nv_fp8_e5m2 x 2
373
+ template <>
374
+ struct vec_t<__nv_fp8_e5m2, 2> {
375
+ __nv_fp8x2_e5m2 data;
376
+
377
+ FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)(&data))[i]; }
378
+ FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
379
+ return ((const __nv_fp8_e5m2*)(&data))[i];
380
+ }
381
+ FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
382
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
383
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr);
384
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const;
385
+ template <typename T>
386
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
387
+ cast_from_impl(*this, src);
388
+ }
389
+ template <typename T>
390
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
391
+ cast_load_impl(*this, ptr);
392
+ }
393
+ template <typename T>
394
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
395
+ cast_store_impl(ptr, *this);
396
+ }
397
+
398
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src);
399
+ };
400
+
401
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
402
+ data.__x = (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
403
+ }
404
+
405
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) {
406
+ data = *((__nv_fp8x2_e5m2*)ptr);
407
+ }
408
+
409
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(__nv_fp8_e5m2* ptr) const {
410
+ *((__nv_fp8x2_e5m2*)ptr) = data;
411
+ }
412
+
413
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(__nv_fp8_e5m2* dst,
414
+ const __nv_fp8_e5m2* src) {
415
+ *((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src);
416
+ }
417
+
418
+ // __nv_fp8_e5m2 x 4
419
+
420
+ template <>
421
+ struct vec_t<__nv_fp8_e5m2, 4> {
422
+ __nv_fp8x4_e5m2 data;
423
+
424
+ FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)(&data))[i]; }
425
+ FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
426
+ return ((const __nv_fp8_e5m2*)(&data))[i];
427
+ }
428
+ FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
429
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
430
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr);
431
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const;
432
+ template <typename T>
433
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
434
+ cast_from_impl(*this, src);
435
+ }
436
+ template <typename T>
437
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
438
+ cast_load_impl(*this, ptr);
439
+ }
440
+ template <typename T>
441
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
442
+ cast_store_impl(ptr, *this);
443
+ }
444
+
445
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src);
446
+ };
447
+
448
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
449
+ data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
450
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
451
+ }
452
+
453
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) {
454
+ data = *((__nv_fp8x4_e5m2*)ptr);
455
+ }
456
+
457
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(__nv_fp8_e5m2* ptr) const {
458
+ *((__nv_fp8x4_e5m2*)ptr) = data;
459
+ }
460
+
461
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(__nv_fp8_e5m2* dst,
462
+ const __nv_fp8_e5m2* src) {
463
+ *((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src);
464
+ }
465
+
466
+ // __nv_fp8_e5m2 x 8
467
+
468
+ template <>
469
+ struct vec_t<__nv_fp8_e5m2, 8> {
470
+ uint2 data;
471
+
472
+ FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)(&data))[i]; }
473
+ FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
474
+ return ((const __nv_fp8_e5m2*)(&data))[i];
475
+ }
476
+ FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
477
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
478
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr);
479
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const;
480
+ template <typename T>
481
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 8>& src) {
482
+ cast_from_impl(*this, src);
483
+ }
484
+ template <typename T>
485
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
486
+ cast_load_impl(*this, ptr);
487
+ }
488
+ template <typename T>
489
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
490
+ cast_store_impl(ptr, *this);
491
+ }
492
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src);
493
+ };
494
+
495
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
496
+ ((__nv_fp8x4_e5m2*)(&data.x))->__x =
497
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
498
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
499
+ ((__nv_fp8x4_e5m2*)(&data.y))->__x =
500
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
501
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
502
+ }
503
+
504
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) {
505
+ data = *((uint2*)ptr);
506
+ }
507
+
508
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(__nv_fp8_e5m2* ptr) const {
509
+ *((uint2*)ptr) = data;
510
+ }
511
+
512
+ FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(__nv_fp8_e5m2* dst,
513
+ const __nv_fp8_e5m2* src) {
514
+ *((uint2*)dst) = *((uint2*)src);
515
+ }
516
+
517
+ // __nv_fp8_e5m2 x 16 or more
518
+
519
+ template <size_t vec_size>
520
+ struct vec_t<__nv_fp8_e5m2, vec_size> {
521
+ uint4 data[vec_size / 16];
522
+
523
+ FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)data)[i]; }
524
+ FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
525
+ return ((const __nv_fp8_e5m2*)data)[i];
526
+ }
527
+ FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
528
+ FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) {
529
+ #pragma unroll
530
+ for (size_t i = 0; i < vec_size / 16; ++i) {
531
+ ((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x =
532
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
533
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
534
+ ((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x =
535
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
536
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
537
+ ((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x =
538
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
539
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
540
+ ((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x =
541
+ (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
542
+ (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
543
+ }
544
+ }
545
+ FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr) {
546
+ #pragma unroll
547
+ for (size_t i = 0; i < vec_size / 16; ++i) {
548
+ data[i] = ((uint4*)ptr)[i];
549
+ }
550
+ }
551
+ FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const {
552
+ #pragma unroll
553
+ for (size_t i = 0; i < vec_size / 16; ++i) {
554
+ ((uint4*)ptr)[i] = data[i];
555
+ }
556
+ }
557
+ template <typename T>
558
+ FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
559
+ cast_from_impl(*this, src);
560
+ }
561
+ template <typename T>
562
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
563
+ cast_load_impl(*this, ptr);
564
+ }
565
+ template <typename T>
566
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
567
+ cast_store_impl(ptr, *this);
568
+ }
569
+ FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
570
+ #pragma unroll
571
+ for (size_t i = 0; i < vec_size / 16; ++i) {
572
+ ((uint4*)dst)[i] = ((uint4*)src)[i];
573
+ }
574
+ }
575
+ };
576
+ #endif
577
+
578
+ /******************* vec_t<half> *******************/
579
+
580
+ // half x 1
581
+ template <>
582
+ struct vec_t<half, 1> {
583
+ half data;
584
+
585
+ FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
586
+ FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; }
587
+ FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
588
+ FLASHINFER_INLINE void fill(half val);
589
+ FLASHINFER_INLINE void load(const half* ptr);
590
+ FLASHINFER_INLINE void store(half* ptr) const;
591
+ template <typename T>
592
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
593
+ cast_from_impl(*this, src);
594
+ }
595
+ template <typename T>
596
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
597
+ cast_load_impl(*this, ptr);
598
+ }
599
+ template <typename T>
600
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
601
+ cast_store_impl(ptr, *this);
602
+ }
603
+
604
+ FLASHINFER_INLINE static void memcpy(half* dst, const half* src);
605
+ };
606
+
607
+ FLASHINFER_INLINE void vec_t<half, 1>::fill(half val) { data = val; }
608
+
609
+ FLASHINFER_INLINE void vec_t<half, 1>::load(const half* ptr) { data = *ptr; }
610
+
611
+ FLASHINFER_INLINE void vec_t<half, 1>::store(half* ptr) const { *ptr = data; }
612
+
613
+ FLASHINFER_INLINE void vec_t<half, 1>::memcpy(half* dst, const half* src) { *dst = *src; }
614
+
615
+ // half x 2
616
+ template <>
617
+ struct vec_t<half, 2> {
618
+ half2 data;
619
+
620
+ FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
621
+ FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; }
622
+ FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
623
+ FLASHINFER_INLINE void fill(half val);
624
+ FLASHINFER_INLINE void load(const half* ptr);
625
+ FLASHINFER_INLINE void store(half* ptr) const;
626
+ template <typename T>
627
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
628
+ cast_from_impl(*this, src);
629
+ }
630
+ template <typename T>
631
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
632
+ cast_load_impl(*this, ptr);
633
+ }
634
+ template <typename T>
635
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
636
+ cast_store_impl(ptr, *this);
637
+ }
638
+
639
+ FLASHINFER_INLINE static void memcpy(half* dst, const half* src);
640
+ };
641
+
642
+ FLASHINFER_INLINE void vec_t<half, 2>::fill(half val) { data = make_half2(val, val); }
643
+
644
+ FLASHINFER_INLINE void vec_t<half, 2>::load(const half* ptr) { data = *((half2*)ptr); }
645
+
646
+ FLASHINFER_INLINE void vec_t<half, 2>::store(half* ptr) const { *((half2*)ptr) = data; }
647
+
648
+ FLASHINFER_INLINE void vec_t<half, 2>::memcpy(half* dst, const half* src) {
649
+ *((half2*)dst) = *((half2*)src);
650
+ }
651
+
652
+ // half x 4
653
+
654
+ template <>
655
+ struct vec_t<half, 4> {
656
+ uint2 data;
657
+
658
+ FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
659
+ FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; }
660
+ FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
661
+ FLASHINFER_INLINE void fill(half val);
662
+ FLASHINFER_INLINE void load(const half* ptr);
663
+ FLASHINFER_INLINE void store(half* ptr) const;
664
+ template <typename T>
665
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
666
+ cast_from_impl(*this, src);
667
+ }
668
+ template <typename T>
669
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
670
+ cast_load_impl(*this, ptr);
671
+ }
672
+ template <typename T>
673
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
674
+ cast_store_impl(ptr, *this);
675
+ }
676
+ FLASHINFER_INLINE static void memcpy(half* dst, const half* src);
677
+ };
678
+
679
+ FLASHINFER_INLINE void vec_t<half, 4>::fill(half val) {
680
+ *(half2*)(&data.x) = make_half2(val, val);
681
+ *(half2*)(&data.y) = make_half2(val, val);
682
+ }
683
+
684
+ FLASHINFER_INLINE void vec_t<half, 4>::load(const half* ptr) { data = *((uint2*)ptr); }
685
+
686
+ FLASHINFER_INLINE void vec_t<half, 4>::store(half* ptr) const { *((uint2*)ptr) = data; }
687
+
688
+ FLASHINFER_INLINE void vec_t<half, 4>::memcpy(half* dst, const half* src) {
689
+ *((uint2*)dst) = *((uint2*)src);
690
+ }
691
+
692
+ // half x 8 or more
693
+
694
+ template <size_t vec_size>
695
+ struct vec_t<half, vec_size> {
696
+ uint4 data[vec_size / 8];
697
+ FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)data)[i]; }
698
+ FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)data)[i]; }
699
+ FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
700
+ FLASHINFER_INLINE void fill(half val) {
701
+ #pragma unroll
702
+ for (size_t i = 0; i < vec_size / 8; ++i) {
703
+ *(half2*)(&(data[i].x)) = make_half2(val, val);
704
+ *(half2*)(&(data[i].y)) = make_half2(val, val);
705
+ *(half2*)(&(data[i].z)) = make_half2(val, val);
706
+ *(half2*)(&(data[i].w)) = make_half2(val, val);
707
+ }
708
+ }
709
+ FLASHINFER_INLINE void load(const half* ptr) {
710
+ #pragma unroll
711
+ for (size_t i = 0; i < vec_size / 8; ++i) {
712
+ data[i] = ((uint4*)ptr)[i];
713
+ }
714
+ }
715
+ FLASHINFER_INLINE void store(half* ptr) const {
716
+ #pragma unroll
717
+ for (size_t i = 0; i < vec_size / 8; ++i) {
718
+ ((uint4*)ptr)[i] = data[i];
719
+ }
720
+ }
721
+ template <typename T>
722
+ FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
723
+ cast_from_impl(*this, src);
724
+ }
725
+ template <typename T>
726
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
727
+ cast_load_impl(*this, ptr);
728
+ }
729
+ template <typename T>
730
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
731
+ cast_store_impl(ptr, *this);
732
+ }
733
+ FLASHINFER_INLINE static void memcpy(half* dst, const half* src) {
734
+ #pragma unroll
735
+ for (size_t i = 0; i < vec_size / 8; ++i) {
736
+ ((uint4*)dst)[i] = ((uint4*)src)[i];
737
+ }
738
+ }
739
+ };
740
+
741
+ /******************* vec_t<nv_bfloat16> *******************/
742
+
743
+ // nv_bfloat16 x 1
744
+ template <>
745
+ struct vec_t<nv_bfloat16, 1> {
746
+ nv_bfloat16 data;
747
+ FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { return ((nv_bfloat16*)(&data))[i]; }
748
+ FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const {
749
+ return ((const nv_bfloat16*)(&data))[i];
750
+ }
751
+ FLASHINFER_INLINE nv_bfloat16* ptr() { return reinterpret_cast<nv_bfloat16*>(&data); }
752
+ FLASHINFER_INLINE void fill(nv_bfloat16 val);
753
+ FLASHINFER_INLINE void load(const nv_bfloat16* ptr);
754
+ FLASHINFER_INLINE void store(nv_bfloat16* ptr) const;
755
+ template <typename T>
756
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
757
+ cast_from_impl(*this, src);
758
+ }
759
+ template <typename T>
760
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
761
+ cast_load_impl(*this, ptr);
762
+ }
763
+ template <typename T>
764
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
765
+ cast_store_impl(ptr, *this);
766
+ }
767
+ FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
768
+ };
769
+
770
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) { data = val; }
771
+
772
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) { data = *ptr; }
773
+
774
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const { *ptr = data; }
775
+
776
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16* dst, const nv_bfloat16* src) {
777
+ *dst = *src;
778
+ }
779
+
780
+ // nv_bfloat16 x 2
781
+ template <>
782
+ struct vec_t<nv_bfloat16, 2> {
783
+ nv_bfloat162 data;
784
+
785
+ FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { return ((nv_bfloat16*)(&data))[i]; }
786
+ FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const {
787
+ return ((const nv_bfloat16*)(&data))[i];
788
+ }
789
+ FLASHINFER_INLINE nv_bfloat16* ptr() { return reinterpret_cast<nv_bfloat16*>(&data); }
790
+ FLASHINFER_INLINE void fill(nv_bfloat16 val);
791
+ FLASHINFER_INLINE void load(const nv_bfloat16* ptr);
792
+ FLASHINFER_INLINE void store(nv_bfloat16* ptr) const;
793
+ template <typename T>
794
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
795
+ cast_from_impl(*this, src);
796
+ }
797
+ template <typename T>
798
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
799
+ cast_load_impl(*this, ptr);
800
+ }
801
+ template <typename T>
802
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
803
+ cast_store_impl(ptr, *this);
804
+ }
805
+ FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
806
+ };
807
+
808
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
809
+ data = make_bfloat162(val, val);
810
+ }
811
+
812
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
813
+ data = *((nv_bfloat162*)ptr);
814
+ }
815
+
816
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
817
+ *((nv_bfloat162*)ptr) = data;
818
+ }
819
+
820
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16* dst, const nv_bfloat16* src) {
821
+ *((nv_bfloat162*)dst) = *((nv_bfloat162*)src);
822
+ }
823
+
824
+ // nv_bfloat16 x 4
825
+
826
+ template <>
827
+ struct vec_t<nv_bfloat16, 4> {
828
+ uint2 data;
829
+
830
+ FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { return ((nv_bfloat16*)(&data))[i]; }
831
+ FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const {
832
+ return ((const nv_bfloat16*)(&data))[i];
833
+ }
834
+ FLASHINFER_INLINE nv_bfloat16* ptr() { return reinterpret_cast<nv_bfloat16*>(&data); }
835
+ FLASHINFER_INLINE void fill(nv_bfloat16 val);
836
+ FLASHINFER_INLINE void load(const nv_bfloat16* ptr);
837
+ FLASHINFER_INLINE void store(nv_bfloat16* ptr) const;
838
+ template <typename T>
839
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
840
+ cast_from_impl(*this, src);
841
+ }
842
+ template <typename T>
843
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
844
+ cast_load_impl(*this, ptr);
845
+ }
846
+ template <typename T>
847
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
848
+ cast_store_impl(ptr, *this);
849
+ }
850
+ FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
851
+ };
852
+
853
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
854
+ *(nv_bfloat162*)(&data.x) = make_bfloat162(val, val);
855
+ *(nv_bfloat162*)(&data.y) = make_bfloat162(val, val);
856
+ }
857
+
858
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
859
+ data = *((uint2*)ptr);
860
+ }
861
+
862
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
863
+ *((uint2*)ptr) = data;
864
+ }
865
+
866
+ FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16* dst, const nv_bfloat16* src) {
867
+ *((uint2*)dst) = *((uint2*)src);
868
+ }
869
+
870
+ // nv_bfloat16 x 8 or more
871
+
872
+ template <size_t vec_size>
873
+ struct vec_t<nv_bfloat16, vec_size> {
874
+ uint4 data[vec_size / 8];
875
+
876
+ FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { return ((nv_bfloat16*)data)[i]; }
877
+ FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const {
878
+ return ((const nv_bfloat16*)data)[i];
879
+ }
880
+ FLASHINFER_INLINE nv_bfloat16* ptr() { return reinterpret_cast<nv_bfloat16*>(&data); }
881
+ FLASHINFER_INLINE void fill(nv_bfloat16 val) {
882
+ #pragma unroll
883
+ for (size_t i = 0; i < vec_size / 8; ++i) {
884
+ *(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val);
885
+ *(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val);
886
+ *(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val);
887
+ *(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val);
888
+ }
889
+ }
890
+ FLASHINFER_INLINE void load(const nv_bfloat16* ptr) {
891
+ #pragma unroll
892
+ for (size_t i = 0; i < vec_size / 8; ++i) {
893
+ data[i] = ((uint4*)ptr)[i];
894
+ }
895
+ }
896
+ FLASHINFER_INLINE void store(nv_bfloat16* ptr) const {
897
+ #pragma unroll
898
+ for (size_t i = 0; i < vec_size / 8; ++i) {
899
+ ((uint4*)ptr)[i] = data[i];
900
+ }
901
+ }
902
+ template <typename T>
903
+ FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
904
+ cast_from_impl(*this, src);
905
+ }
906
+ template <typename T>
907
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
908
+ cast_load_impl(*this, ptr);
909
+ }
910
+ template <typename T>
911
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
912
+ cast_store_impl(ptr, *this);
913
+ }
914
+ FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src) {
915
+ #pragma unroll
916
+ for (size_t i = 0; i < vec_size / 8; ++i) {
917
+ ((uint4*)dst)[i] = ((uint4*)src)[i];
918
+ }
919
+ }
920
+ };
921
+
922
+ /******************* vec_t<float> *******************/
923
+
924
+ // float x 1
925
+
926
+ template <>
927
+ struct vec_t<float, 1> {
928
+ float data;
929
+
930
+ FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
931
+ FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; }
932
+ FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
933
+ FLASHINFER_INLINE void fill(float val);
934
+ FLASHINFER_INLINE void load(const float* ptr);
935
+ FLASHINFER_INLINE void store(float* ptr) const;
936
+ template <typename T>
937
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
938
+ cast_from_impl(*this, src);
939
+ }
940
+ template <typename T>
941
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
942
+ cast_load_impl(*this, ptr);
943
+ }
944
+ template <typename T>
945
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
946
+ cast_store_impl(ptr, *this);
947
+ }
948
+ FLASHINFER_INLINE static void memcpy(float* dst, const float* src);
949
+ };
950
+
951
+ FLASHINFER_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
952
+
953
+ FLASHINFER_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
954
+
955
+ FLASHINFER_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
956
+
957
+ FLASHINFER_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) { *dst = *src; }
958
+
959
+ // float x 2
960
+
961
+ template <>
962
+ struct vec_t<float, 2> {
963
+ float2 data;
964
+
965
+ FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
966
+ FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; }
967
+ FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
968
+ FLASHINFER_INLINE void fill(float val);
969
+ FLASHINFER_INLINE void load(const float* ptr);
970
+ FLASHINFER_INLINE void store(float* ptr) const;
971
+ template <typename T>
972
+ FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
973
+ cast_from_impl(*this, src);
974
+ }
975
+ template <typename T>
976
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
977
+ cast_load_impl(*this, ptr);
978
+ }
979
+ template <typename T>
980
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
981
+ cast_store_impl(ptr, *this);
982
+ }
983
+ FLASHINFER_INLINE static void memcpy(float* dst, const float* src);
984
+ };
985
+
986
+ FLASHINFER_INLINE void vec_t<float, 2>::fill(float val) { data = make_float2(val, val); }
987
+
988
+ FLASHINFER_INLINE void vec_t<float, 2>::load(const float* ptr) { data = *((float2*)ptr); }
989
+
990
+ FLASHINFER_INLINE void vec_t<float, 2>::store(float* ptr) const { *((float2*)ptr) = data; }
991
+
992
+ FLASHINFER_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
993
+ *((float2*)dst) = *((float2*)src);
994
+ }
995
+
996
+ // float x 4 or more
997
+ template <size_t vec_size>
998
+ struct vec_t<float, vec_size> {
999
+ float4 data[vec_size / 4];
1000
+
1001
+ FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
1002
+ FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(data))[i]; }
1003
+ FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
1004
+ FLASHINFER_INLINE void fill(float val) {
1005
+ #pragma unroll
1006
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1007
+ data[i] = make_float4(val, val, val, val);
1008
+ }
1009
+ }
1010
+ FLASHINFER_INLINE void load(const float* ptr) {
1011
+ #pragma unroll
1012
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1013
+ data[i] = ((float4*)ptr)[i];
1014
+ }
1015
+ }
1016
+ FLASHINFER_INLINE void store(float* ptr) const {
1017
+ #pragma unroll
1018
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1019
+ ((float4*)ptr)[i] = data[i];
1020
+ }
1021
+ }
1022
+ template <typename T>
1023
+ FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
1024
+ cast_from_impl(*this, src);
1025
+ }
1026
+ template <typename T>
1027
+ FLASHINFER_INLINE void cast_load(const T* ptr) {
1028
+ cast_load_impl(*this, ptr);
1029
+ }
1030
+ template <typename T>
1031
+ FLASHINFER_INLINE void cast_store(T* ptr) const {
1032
+ cast_store_impl(ptr, *this);
1033
+ }
1034
+ FLASHINFER_INLINE static void memcpy(float* dst, const float* src) {
1035
+ #pragma unroll
1036
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1037
+ ((float4*)dst)[i] = ((float4*)src)[i];
1038
+ }
1039
+ }
1040
+ };
1041
+
1042
+ /******************* vec_t type cast *******************/
1043
+
1044
+ template <typename dst_t, typename src_t, size_t vec_size>
1045
+ FLASHINFER_INLINE void vec_cast(dst_t* dst, const src_t* src) {
1046
+ #pragma unroll
1047
+ for (size_t i = 0; i < vec_size; ++i) {
1048
+ dst[i] = src[i];
1049
+ }
1050
+ }
1051
+
1052
+ template <size_t vec_size>
1053
+ FLASHINFER_INLINE void vec_cast<float, half>(float* dst, const half* src) {
1054
+ #pragma unroll
1055
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1056
+ ((float2*)dst)[i] = __half22float2(((half2*)src)[i]);
1057
+ }
1058
+ }
1059
+
1060
+ template <size_t vec_size>
1061
+ FLASHINFER_INLINE void vec_cast<half, float>(half* dst, const float* src) {
1062
+ #pragma unroll
1063
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1064
+ ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]);
1065
+ }
1066
+ }
1067
+
1068
+ template <size_t vec_size>
1069
+ FLASHINFER_INLINE void vec_cast<float, nv_bfloat16>(float* dst, const nv_bfloat16* src) {
1070
+ #pragma unroll
1071
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1072
+ ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]);
1073
+ }
1074
+ }
1075
+
1076
+ template <size_t vec_size>
1077
+ FLASHINFER_INLINE void vec_cast<nv_bfloat16, float>(nv_bfloat16* dst, const float* src) {
1078
+ #pragma unroll
1079
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1080
+ ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]);
1081
+ }
1082
+ }
1083
+
1084
+ template <size_t vec_size>
1085
+ FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst,
1086
+ const vec_t<half, vec_size>& src) {
1087
+ if constexpr (vec_size == 1) {
1088
+ dst.data = float(src.data);
1089
+ } else {
1090
+ #pragma unroll
1091
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1092
+ ((float2*)(&dst.data))[i] = __half22float2(((half2*)(&src.data))[i]);
1093
+ }
1094
+ }
1095
+ }
1096
+
1097
+ template <size_t vec_size>
1098
+ FLASHINFER_INLINE void cast_from_impl(vec_t<half, vec_size>& dst,
1099
+ const vec_t<float, vec_size>& src) {
1100
+ if constexpr (vec_size == 1) {
1101
+ dst.data = half(src.data);
1102
+ } else {
1103
+ #pragma unroll
1104
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1105
+ ((half2*)(&dst.data))[i] = __float22half2_rn(((float2*)(&src.data))[i]);
1106
+ }
1107
+ }
1108
+ }
1109
+
1110
+ template <size_t vec_size>
1111
+ FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst,
1112
+ const vec_t<nv_bfloat16, vec_size>& src) {
1113
+ if constexpr (vec_size == 1) {
1114
+ dst.data = float(src.data);
1115
+ } else {
1116
+ #pragma unroll
1117
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1118
+ ((float2*)(&dst.data))[i] = __bfloat1622float2(((nv_bfloat162*)(&src.data))[i]);
1119
+ }
1120
+ }
1121
+ }
1122
+
1123
+ template <size_t vec_size>
1124
+ FLASHINFER_INLINE void cast_from_impl(vec_t<nv_bfloat16, vec_size>& dst,
1125
+ const vec_t<float, vec_size>& src) {
1126
+ if constexpr (vec_size == 1) {
1127
+ dst.data = nv_bfloat16(src.data);
1128
+ } else {
1129
+ #pragma unroll
1130
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1131
+ ((nv_bfloat162*)(&dst.data))[i] = __float22bfloat162_rn(((float2*)(&src.data))[i]);
1132
+ }
1133
+ }
1134
+ }
1135
+
1136
+ #ifdef FLASHINFER_ENABLE_FP8
1137
+
1138
+ template <size_t vec_size>
1139
+ FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst,
1140
+ const vec_t<__nv_fp8_e4m3, vec_size>& src) {
1141
+ if constexpr (vec_size == 1) {
1142
+ dst.data = float(src.data);
1143
+ } else if constexpr (vec_size == 2) {
1144
+ *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e4m3*)(&src.data));
1145
+ } else {
1146
+ #pragma unroll
1147
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1148
+ ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3*)(&src.data))[i]);
1149
+ }
1150
+ }
1151
+ }
1152
+
1153
+ template <size_t vec_size>
1154
+ FLASHINFER_INLINE void cast_from_impl(vec_t<half, vec_size>& dst,
1155
+ const vec_t<__nv_fp8_e4m3, vec_size>& src) {
1156
+ if constexpr (vec_size == 1) {
1157
+ dst.data = float(src.data);
1158
+ } else {
1159
+ #pragma unroll
1160
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1161
+ ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3*)(&src.data))[i]);
1162
+ }
1163
+ }
1164
+ }
1165
+
1166
+ template <size_t vec_size>
1167
+ FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst,
1168
+ const vec_t<float, vec_size>& src) {
1169
+ if constexpr (vec_size == 1) {
1170
+ dst.data = __nv_fp8_e4m3(src.data);
1171
+ } else if constexpr (vec_size == 2) {
1172
+ *(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(float2*)(&src.data));
1173
+ } else {
1174
+ #pragma unroll
1175
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1176
+ ((__nv_fp8x4_e4m3*)(&dst.data))[i] = __nv_fp8x4_e4m3(((float4*)(&src.data))[i]);
1177
+ }
1178
+ }
1179
+ }
1180
+
1181
+ template <size_t vec_size>
1182
+ FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst,
1183
+ const vec_t<half, vec_size>& src) {
1184
+ if constexpr (vec_size == 1) {
1185
+ dst.data = __nv_fp8_e4m3(src.data);
1186
+ } else if constexpr (vec_size == 2) {
1187
+ *(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(half2*)(&src.data));
1188
+ } else {
1189
+ #pragma unroll
1190
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1191
+ // NOTE(Zihao): need to double check if we properly handle flo and fhi
1192
+ ((__nv_fp8x4_e4m3*)(&dst.data))[i] =
1193
+ __nv_fp8x4_e4m3(((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]);
1194
+ }
1195
+ }
1196
+ }
1197
+
1198
+ template <size_t vec_size>
1199
+ FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst,
1200
+ const vec_t<__nv_fp8_e5m2, vec_size>& src) {
1201
+ if constexpr (vec_size == 1) {
1202
+ dst.data = float(src.data);
1203
+ } else if constexpr (vec_size == 2) {
1204
+ *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e5m2*)(&src.data));
1205
+ } else {
1206
+ #pragma unroll
1207
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1208
+ ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2*)(&src.data))[i]);
1209
+ }
1210
+ }
1211
+ }
1212
+
1213
+ template <size_t vec_size>
1214
+ FLASHINFER_INLINE void cast_from_impl(vec_t<half, vec_size>& dst,
1215
+ const vec_t<__nv_fp8_e5m2, vec_size>& src) {
1216
+ if constexpr (vec_size == 1) {
1217
+ dst.data = float(src.data);
1218
+ } else {
1219
+ #pragma unroll
1220
+ for (size_t i = 0; i < vec_size / 2; ++i) {
1221
+ ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2*)(&src.data))[i]);
1222
+ }
1223
+ }
1224
+ }
1225
+
1226
+ template <size_t vec_size>
1227
+ FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst,
1228
+ const vec_t<float, vec_size>& src) {
1229
+ if constexpr (vec_size == 1) {
1230
+ dst.data = __nv_fp8_e5m2(src.data);
1231
+ } else if constexpr (vec_size == 2) {
1232
+ *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(float2*)(&src.data));
1233
+ } else {
1234
+ #pragma unroll
1235
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1236
+ ((__nv_fp8x4_e5m2*)(&dst.data))[i] = __nv_fp8x4_e5m2(((float4*)(&src.data))[i]);
1237
+ }
1238
+ }
1239
+ }
1240
+
1241
+ template <size_t vec_size>
1242
+ FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst,
1243
+ const vec_t<half, vec_size>& src) {
1244
+ if constexpr (vec_size == 1) {
1245
+ dst.data = __nv_fp8_e5m2(src.data);
1246
+ } else if constexpr (vec_size == 2) {
1247
+ *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(half2*)(&src.data));
1248
+ } else {
1249
+ #pragma unroll
1250
+ for (size_t i = 0; i < vec_size / 4; ++i) {
1251
+ // NOTE(Zihao): need to double check if we properly handle flo and fhi
1252
+ ((__nv_fp8x4_e5m2*)(&dst.data))[i] =
1253
+ __nv_fp8x4_e5m2(((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]);
1254
+ }
1255
+ }
1256
+ }
1257
+
1258
+ #endif // FLASHINFER_ENABLE_FP8
1259
+
1260
+ } // namespace flashinfer
1261
+
1262
+ #endif // VEC_DTYPES_CUH_
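
The vec_t specializations above give the kernels a uniform way to do vectorized loads and stores with on-the-fly dtype conversion (cast_load, cast_store, fill). A minimal sketch of the intended usage pattern, assuming a hypothetical elementwise kernel that is not part of this commit:

    // Hypothetical example (not in this commit): cast-load a vector of halves,
    // scale in fp32, and cast-store the result, using the vec_t API above.
    #include <cuda_fp16.h>
    #include "flashinfer/vec_dtypes.cuh"

    template <uint32_t vec_size>
    __global__ void scale_cast_kernel(half* __restrict__ out,
                                      const half* __restrict__ in,
                                      float scale, size_t n) {
      size_t i = (blockIdx.x * size_t(blockDim.x) + threadIdx.x) * vec_size;
      if (i + vec_size <= n) {
        flashinfer::vec_t<float, vec_size> v;
        v.cast_load(in + i);    // half -> float on load
    #pragma unroll
        for (uint32_t j = 0; j < vec_size; ++j) {
          v[j] *= scale;        // compute in fp32
        }
        v.cast_store(out + i);  // float -> half on store
      }
    }

The sgmv/bgmv kernels below follow the same pattern, with vec_size typically chosen so that each access maps onto one of the wide, uint4-backed specializations above.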
punica_kernels/punica_ops.cc ADDED
@@ -0,0 +1,220 @@
1
+ #include <c10/cuda/CUDAStream.h>
2
+ #include <cuda_bf16.h>
3
+ #include <cuda_fp16.h>
4
+ #include <torch/all.h>
5
+
6
+ #include <cstdint>
7
+
8
+ #include "bgmv/bgmv_config.h"
9
+ #include "sgmv/sgmv.h"
10
+ #include "sgmv_flashinfer/sgmv_config.h"
11
+
12
+ //namespace
13
+ //{
14
+
15
+ //====== utils ======
16
+
17
+ inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b)
18
+ {
19
+ return (uint64_t(a) << 32) | uint64_t(b);
20
+ }
21
+
22
+ #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
23
+
24
+ #define CHECK_CONTIGUOUS(x) \
25
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
26
+
27
+ #define CHECK_INPUT(x) \
28
+ CHECK_CUDA(x); \
29
+ CHECK_CONTIGUOUS(x)
30
+
31
+ #define CHECK_DIM(d, x) \
32
+ TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
33
+
34
+ #define CHECK_EQ(a, b) \
35
+ TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
36
+
37
+ #define CHECK_GE(a, b) \
38
+ TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
39
+
40
+ //====== dispatch pytorch dtype ======
41
+
42
+ #define _DISPATCH_SWITCH(cond, ...) \
43
+ [&]() -> bool { \
44
+ switch (cond) \
45
+ { \
46
+ __VA_ARGS__ \
47
+ default: \
48
+ return false; \
49
+ } \
50
+ }()
51
+
52
+ #define _DISPATCH_DTYPE_CASE(enum_type, c_type_, ...) \
53
+ case enum_type: \
54
+ { \
55
+ using c_type = c_type_; \
56
+ return __VA_ARGS__(); \
57
+ }
58
+
59
+ #define _DISPATCH_DTYPE_CASES(...) \
60
+ _DISPATCH_DTYPE_CASE(at::ScalarType::Half, nv_half, __VA_ARGS__) \
61
+ _DISPATCH_DTYPE_CASE(at::ScalarType::BFloat16, nv_bfloat16, __VA_ARGS__)
62
+
63
+ #define DISPATCH_TORCH_DTYPE(scalar_type, ...) \
64
+ _DISPATCH_SWITCH(scalar_type, _DISPATCH_DTYPE_CASES(__VA_ARGS__))
65
+
66
+ //====== bgmv ======
67
+
68
+ template <typename T>
69
+ inline bool launch_bgmv_kernel(T *Y, const T *X, T **W,
70
+ const int64_t *lora_indices,
71
+ uint16_t in_features, uint16_t out_features,
72
+ int64_t y_offset, int64_t full_y_size,
73
+ int64_t batch_size,
74
+ int64_t layer_idx, float scale)
75
+ {
76
+ switch (pack_u32(in_features, out_features))
77
+ {
78
+ #define CASE_ONESIDE(_T, feat_in, feat_out) \
79
+ case pack_u32(feat_in, feat_out): \
80
+ bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
81
+ full_y_size, batch_size, \
82
+ layer_idx, scale); \
83
+ break;
84
+ #define CASE(_T, narrow, wide) \
85
+ CASE_ONESIDE(T, narrow, wide) \
86
+ CASE_ONESIDE(T, wide, narrow)
87
+
88
+ FOR_BGMV_WIDE_NARROW(CASE, _)
89
+ #undef CASE
90
+ #undef CASE_ONESIDE
91
+ default:
92
+ return false;
93
+ }
94
+
95
+ return true;
96
+ }
97
+
98
+ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
99
+ torch::Tensor indicies, int64_t layer_idx, double scale)
100
+ {
101
+ CHECK_INPUT(y);
102
+ CHECK_INPUT(x);
103
+ CHECK_INPUT(w_ptr);
104
+ CHECK_INPUT(indicies);
105
+
106
+ CHECK_DIM(2, y);
107
+ CHECK_DIM(2, x);
108
+ CHECK_DIM(1, w_ptr);
109
+ CHECK_DIM(1, indicies);
110
+
111
+ int64_t B = x.size(0);
112
+ int64_t h_in = x.size(1);
113
+ int64_t h_out = y.size(1);
114
+ CHECK_EQ(indicies.size(0), x.size(0));
115
+ CHECK_EQ(y.size(0), x.size(0));
116
+ bool ok = false;
117
+ if (h_in < 65536 && h_out < 65536)
118
+ {
119
+ switch (x.scalar_type())
120
+ {
121
+ case at::ScalarType::Half:
122
+ ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
123
+ static_cast<nv_half *>(x.data_ptr()),
124
+ static_cast<nv_half **>(w_ptr.data_ptr()),
125
+ indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
126
+ layer_idx, scale);
127
+ break;
128
+ case at::ScalarType::BFloat16:
129
+ ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
130
+ static_cast<nv_bfloat16 *>(x.data_ptr()),
131
+ static_cast<nv_bfloat16 **>(w_ptr.data_ptr()),
132
+ indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
133
+ layer_idx, scale);
134
+ break;
135
+ default:
136
+ break;
137
+ }
138
+ }
139
+ TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
140
+ " dtype=", x.scalar_type());
141
+ }
142
+
143
+ //====== sgmv ======
144
+
145
+ void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
146
+ torch::Tensor s_start, torch::Tensor s_end,
147
+ torch::Tensor tmp, int64_t layer_idx)
148
+ {
149
+ CHECK_INPUT(y);
150
+ CHECK_INPUT(x);
151
+ CHECK_INPUT(w_ptr);
152
+ CHECK_INPUT(s_start);
153
+ CHECK_INPUT(s_end);
154
+ CHECK_INPUT(tmp);
155
+
156
+ CHECK_DIM(2, y);
157
+ CHECK_DIM(2, x);
158
+ CHECK_DIM(1, w_ptr);
159
+ CHECK_DIM(1, s_start);
160
+ CHECK_DIM(1, s_end);
161
+ CHECK_DIM(1, tmp);
162
+
163
+ int num_problems = s_start.size(0);
164
+ int d_in = x.size(1);
165
+ int d_out = y.size(1);
166
+ CHECK_EQ(tmp.size(0), static_cast<int64_t>(sgmv_tmp_size(num_problems)));
167
+ cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
168
+ bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&]
169
+ { return sgmv<c_type>((c_type *)y.data_ptr(), (c_type *)x.data_ptr(), (c_type **)w_ptr.data_ptr(),
170
+ s_start.data_ptr<int32_t>(), s_end.data_ptr<int32_t>(),
171
+ tmp.data_ptr<uint8_t>(), num_problems, d_in, d_out,
172
+ layer_idx, stream); });
173
+ TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type());
174
+ }
175
+
176
+ void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
177
+ torch::Tensor s_start, torch::Tensor s_end, torch::Tensor tmp, int64_t layer_idx)
178
+ {
179
+ CHECK_INPUT(y);
180
+ CHECK_INPUT(x);
181
+ CHECK_INPUT(w_ptr);
182
+ CHECK_INPUT(s_start);
183
+ CHECK_INPUT(s_end);
184
+ CHECK_INPUT(tmp);
185
+
186
+ CHECK_DIM(2, y);
187
+ CHECK_DIM(2, x);
188
+ CHECK_DIM(1, w_ptr);
189
+ CHECK_DIM(1, s_start);
190
+ CHECK_DIM(1, s_end);
191
+ CHECK_DIM(1, tmp);
192
+
193
+ uint32_t num_problems = s_start.size(0);
194
+ uint32_t d_in = x.size(1);
195
+ uint32_t d_out = y.size(1);
196
+ CHECK_EQ(tmp.scalar_type(), at::ScalarType::Byte);
197
+ CHECK_EQ(tmp.size(0), 8 * 1024 * 1024);
198
+ cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
199
+
200
+ #define CASE(_T, D_OUT) \
201
+ case D_OUT: \
202
+ return sgmv_shrink<c_type, D_OUT>( \
203
+ (c_type *)y.data_ptr(), (c_type *)x.data_ptr(), \
204
+ (c_type **)w_ptr.data_ptr(), s_start.data_ptr<int32_t>(), s_end.data_ptr<int32_t>(), \
205
+ tmp.data_ptr<uint8_t>(), num_problems, d_in, layer_idx, stream);
206
+
207
+ bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&]
208
+ {
209
+ switch (d_out) {
210
+ FOR_SGMV_NARROW(CASE, c_type);
211
+ default:
212
+ return false;
213
+ } });
214
+
215
+ #undef CASE
216
+ TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type(),
217
+ " d_out=", d_out);
218
+ }
219
+ //} // namespace
220
+
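
punica_ops.cc only defines the host-side dispatch functions; the Torch binding that exposes them to Python lives elsewhere in the build. A minimal sketch of such a binding, assuming a conventional pybind11 torch extension (the layout and docstrings here are illustrative, not the repository's actual binding file):

    // Hypothetical binding sketch, not part of punica_ops.cc.
    #include <torch/extension.h>

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      m.def("dispatch_bgmv", &dispatch_bgmv,
            "Batched gather matrix-vector multiply (bgmv)");
      m.def("dispatch_sgmv_cutlass", &dispatch_sgmv_cutlass,
            "Segmented gather matrix-vector multiply via CUTLASS grouped GEMM");
      m.def("dispatch_sgmv_shrink", &dispatch_sgmv_shrink,
            "Segmented gather matrix-vector shrink (flashinfer path)");
    }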
sgmv/sgmv.h ADDED
@@ -0,0 +1,10 @@
1
+ #pragma once
2
+ #include <cuda_runtime.h>
3
+
4
+ #include <cstdint>
5
+
6
+ template <typename DType>
7
+ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
8
+ void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx, cudaStream_t stream);
9
+
10
+ int64_t sgmv_tmp_size(int64_t num_problems);
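
The caller owns the scratch buffer: sgmv_tmp_size reports how many bytes the grouped-GEMM path needs for a given number of problems, and that buffer is passed in as tmp_d. A minimal host-side sketch, assuming raw CUDA allocation and device pointers already prepared by the caller (run_sgmv is an illustrative helper, not part of this header):

    // Hypothetical usage of the sgmv API declared above.
    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include "sgmv/sgmv.h"

    bool run_sgmv(nv_half* y, nv_half* x, nv_half** w,
                  int32_t* s_start, int32_t* s_end,
                  int num_problems, int d_in, int d_out,
                  int layer_idx, cudaStream_t stream) {
      void* tmp_d = nullptr;  // scratch for the grouped-GEMM argument arrays
      cudaMalloc(&tmp_d, sgmv_tmp_size(num_problems));
      bool ok = sgmv<nv_half>(y, x, w, s_start, s_end, tmp_d,
                              num_problems, d_in, d_out, layer_idx, stream);
      cudaStreamSynchronize(stream);  // tmp_d is read asynchronously by the launched kernels
      cudaFree(tmp_d);
      return ok;
    }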
sgmv/sgmv_cutlass.cu ADDED
@@ -0,0 +1,14 @@
1
+ #include <cuda_bf16.h>
2
+ #include <cuda_fp16.h>
3
+
4
+ #include "sgmv_cutlass.cuh"
5
+
6
+ template bool sgmv<nv_half>(nv_half *y, nv_half *x, nv_half **w,
7
+ int32_t *s_start, int32_t *s_end,
8
+ void *tmp_d, int num_problems, int d_in, int d_out,
9
+ int layer_idx, cudaStream_t stream);
10
+
11
+ template bool sgmv<nv_bfloat16>(nv_bfloat16 *y, nv_bfloat16 *x, nv_bfloat16 **w,
12
+ int32_t *s_start, int32_t *s_end,
13
+ void *tmp_d, int num_problems, int d_in, int d_out,
14
+ int layer_idx, cudaStream_t stream);
sgmv/sgmv_cutlass.cuh ADDED
@@ -0,0 +1,180 @@
1
+ #pragma once
2
+ #include <cuda_bf16.h>
3
+ #include <cuda_fp16.h>
4
+ #include <cuda_runtime.h>
5
+
6
+ #include <cstdint>
7
+ #include <cstdio>
8
+
9
+ #include "cutlass/cutlass.h"
10
+ #include "cutlass/gemm/device/gemm_grouped.h"
11
+ #include "cutlass/gemm/kernel/default_gemm_grouped.h"
12
+ #include "cutlass/layout/matrix.h"
13
+ #include "cutlass/numeric_types.h"
14
+
15
+ template <typename T>
16
+ struct cutlass_dtype {
17
+ using type = T;
18
+ };
19
+
20
+ template <>
21
+ struct cutlass_dtype<half> {
22
+ using type = cutlass::half_t;
23
+ };
24
+
25
+ template <>
26
+ struct cutlass_dtype<nv_bfloat16> {
27
+ using type = cutlass::bfloat16_t;
28
+ };
29
+
30
+ template <typename T>
31
+ __global__ void precompute_sgmv_args(cutlass::gemm::GemmCoord *all_problems,
32
+ T **ptr_y, T **ptr_x, T **ptr_w,
33
+ int64_t *ld_y, int64_t *ld_x,
34
+ int64_t *ld_w, T *y, T *x, T **w,
35
+ int32_t *s_start, int32_t *s_end,
36
+ int d_in, int d_out,
37
+ int layer_idx) {
38
+ int i = blockIdx.x;
39
+ int m = s_end[i] - s_start[i], k = d_in, n = d_out;
40
+ if (m <= 0) {
41
+ m = 0;
42
+ n = 0;
43
+ k = 0;
44
+ }
45
+ all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
46
+ ptr_w[i] = w[i] + layer_idx * d_in * d_out;
47
+ ptr_x[i] = x + s_start[i] * d_in;
48
+ ptr_y[i] = y + s_start[i] * d_out;
49
+ ld_x[i] = k;
50
+ ld_w[i] = n;
51
+ ld_y[i] = n;
52
+ }
53
+
54
+ int64_t sgmv_tmp_size(int64_t num_problems) {
55
+ constexpr auto sz = sizeof(void *) * 3 + sizeof(int64_t) * 3 +
56
+ sizeof(cutlass::gemm::GemmCoord);
57
+ return sz * num_problems;
58
+ }
59
+
60
+ template <typename T>
61
+ inline T *alloc_from_buf(void **buf, int n) {
62
+ auto *p = (T *)*buf;
63
+ *buf = (void *)(p + n);
64
+ return p;
65
+ }
66
+
67
+ template <typename DType>
68
+ bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
69
+ void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx,
70
+ cudaStream_t stream) {
71
+ using cutlass_t = typename cutlass_dtype<DType>::type;
72
+
73
+ auto ptr_Y = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
74
+ auto ptr_X = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
75
+ auto ptr_W = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
76
+ auto ld_Y = alloc_from_buf<int64_t>(&tmp_d, num_problems);
77
+ auto ld_X = alloc_from_buf<int64_t>(&tmp_d, num_problems);
78
+ auto ld_W = alloc_from_buf<int64_t>(&tmp_d, num_problems);
79
+ auto all_problems =
80
+ alloc_from_buf<cutlass::gemm::GemmCoord>(&tmp_d, num_problems);
81
+
82
+ precompute_sgmv_args<<<num_problems, 1, 0, stream>>>(
83
+ all_problems, ptr_Y, ptr_X, ptr_W, ld_Y, ld_X, ld_W, (cutlass_t *)y,
84
+ (cutlass_t *)x, (cutlass_t **)w, s_start, s_end, d_in, d_out, layer_idx);
85
+
86
+ using cutlass::epilogue::thread::LinearCombination;
87
+ using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
88
+ if (d_in < d_out) {
89
+ // Expand
90
+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
91
+ cutlass_t, // Element A
92
+ cutlass::layout::RowMajor, // Layout A
93
+ cutlass::ComplexTransform::kNone, //
94
+ 8, // Granularity A
95
+ cutlass_t, // Element B
96
+ cutlass::layout::RowMajor, // Layout B
97
+ cutlass::ComplexTransform::kNone, //
98
+ 8, // Granularity B
99
+ cutlass_t, // Element C&D
100
+ cutlass::layout::RowMajor, // Layout C&D
101
+ float, // Element Accumulator
102
+ cutlass::arch::OpClassTensorOp, // Operator Class Tag
103
+ cutlass::arch::Sm80, // Architecture
104
+ cutlass::gemm::GemmShape<32, 128, 16>, // Thread Block Shape
105
+ cutlass::gemm::GemmShape<32, 64, 16>, // Warp Shape
106
+ cutlass::gemm::GemmShape<16, 8, 8>, // Instruction Shape
107
+ LinearCombination<cutlass_t, 8, float, float>, // Epilogue
108
+ GemmIdentityThreadblockSwizzle<1>, // Swizzling Operator
109
+ 2 // Stages
110
+ >::GemmKernel;
111
+
112
+ using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
113
+ typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);
114
+
115
+ using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
116
+ typename GemmGrouped::Arguments args(all_problems, num_problems, 512,
117
+ epilogue_op, ptr_X, ptr_W, ptr_Y,
118
+ ptr_Y, ld_X, ld_W, ld_Y, ld_Y);
119
+
120
+ GemmGrouped gemm;
121
+ auto status = gemm.initialize(args, nullptr, stream);
122
+ if (status != cutlass::Status::kSuccess) {
123
+ fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n",
124
+ cutlassGetStatusString(status));
125
+ return false;
126
+ }
127
+ status = gemm.run(stream);
128
+ if (status != cutlass::Status::kSuccess) {
129
+ fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n",
130
+ cutlassGetStatusString(status));
131
+ return false;
132
+ }
133
+ } else {
134
+ // Shrink
135
+ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
136
+ cutlass_t, // Element A
137
+ cutlass::layout::RowMajor, // Layout A
138
+ cutlass::ComplexTransform::kNone, //
139
+ 8, // Granularity A
140
+ cutlass_t, // Element B
141
+ cutlass::layout::RowMajor, // Layout B
142
+ cutlass::ComplexTransform::kNone, //
143
+ 8, // Granularity B
144
+ cutlass_t, // Element C&D
145
+ cutlass::layout::RowMajor, // Layout C&D
146
+ float, // Element Accumulator
147
+ cutlass::arch::OpClassTensorOp, // Operator Class Tag
148
+ cutlass::arch::Sm80, // Architecture
149
+ cutlass::gemm::GemmShape<16, 64, 64>, // Thread Block Shape
150
+ cutlass::gemm::GemmShape<16, 16, 64>, // Warp Shape
151
+ cutlass::gemm::GemmShape<16, 8, 16>, // Instruction Shape
152
+ LinearCombination<cutlass_t, 4, float, float>, // Epilogue
153
+ GemmIdentityThreadblockSwizzle<2>, // Swizzling Operator
154
+ 2 // Stages
155
+ >::GemmKernel;
156
+
157
+ using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
158
+ typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);
159
+
160
+ using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
161
+ typename GemmGrouped::Arguments args(all_problems, num_problems, 512,
162
+ epilogue_op, ptr_X, ptr_W, ptr_Y,
163
+ ptr_Y, ld_X, ld_W, ld_Y, ld_Y);
164
+
165
+ GemmGrouped gemm;
166
+ auto status = gemm.initialize(args, nullptr, stream);
167
+ if (status != cutlass::Status::kSuccess) {
168
+ fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n",
169
+ cutlassGetStatusString(status));
170
+ return false;
171
+ }
172
+ status = gemm.run(stream);
173
+ if (status != cutlass::Status::kSuccess) {
174
+ fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n",
175
+ cutlassGetStatusString(status));
176
+ return false;
177
+ }
178
+ }
179
+ return true;
180
+ }
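
alloc_from_buf carves the caller-provided scratch buffer in exactly the order sgmv_tmp_size accounts for it: three pointer arrays, three leading-dimension arrays and one GemmCoord array, each of length num_problems. Per problem that works out as follows (assuming 64-bit pointers and CUTLASS's three-int GemmCoord):

    // Per-problem scratch accounting, matching sgmv_tmp_size above.
    constexpr size_t kBytesPerProblem =
        3 * sizeof(void*)      // ptr_Y, ptr_X, ptr_W
        + 3 * sizeof(int64_t)  // ld_Y, ld_X, ld_W
        + 3 * sizeof(int);     // GemmCoord {m, n, k}
    static_assert(kBytesPerProblem == 60, "24 + 24 + 12 bytes per problem");

Placing the pointer and int64_t arrays first also keeps every sub-array 8-byte aligned without extra padding.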
sgmv_flashinfer/sgmv_all.cu ADDED
@@ -0,0 +1,73 @@
+ #include <cuda_bf16.h>
+ #include <cuda_fp16.h>
+ #include <cuda_runtime.h>
+
+ #include <cstdint>
+
+ #include "sgmv_config.h"
+ #include "sgmv_flashinfer.cuh"
+
+ template <typename T, uint32_t d_out>
+ bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp,
+                  uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, cudaStream_t stream) {
+   static_assert(d_out % 16 == 0);
+
+   constexpr uint32_t num_warps = 4;
+   constexpr uint32_t num_stages = 2;
+   constexpr uint32_t num_k_frags_per_stage = 8;
+   constexpr uint32_t num_blocks_n = d_out / 16;
+   uint32_t smem = num_stages * sizeof(T) * num_k_frags_per_stage * 16 * 16 *
+                   (num_warps + num_blocks_n);
+   auto cooperative_kernel =
+       flashinfer::sgmv::sgmv_shrink<true, T, int, num_warps, d_out>;
+   auto kernel = flashinfer::sgmv::sgmv_shrink<false, T, int, num_warps, d_out>;
+
+   int dev_id = 0;
+   int num_blocks_per_sm = 0;
+   int num_sm = 0;
+   bool use_cooperative = true;
+   cudaGetDevice(&dev_id);
+   cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id);
+   cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+       &num_blocks_per_sm, cooperative_kernel, num_warps * 32, smem);
+
+   const uint32_t max_grid_size = num_sm * num_blocks_per_sm;
+
+   uint32_t chunk_size = 256;
+   uint32_t num_chunks = (d_in + chunk_size - 1) / chunk_size;
+   if (num_chunks * num_problems > max_grid_size) {
+     use_cooperative = false;
+     chunk_size = d_in;
+     num_chunks = 1;
+   }
+
+   dim3 nthrs(32, num_warps);
+   dim3 nblks(num_chunks, num_problems);
+
+   void* args[] = {(void*)&y,       (void*)&x,     (void*)&w,
+                   (void*)&s_start, (void*)&s_end, (void*)&tmp, (void*)&num_problems,
+                   (void*)&d_in,    (void*)&layer_idx, (void*)&chunk_size};
+
+   cudaError_t status;
+   if (use_cooperative) {
+     if (smem > 46 * 1024) {
+       cudaFuncSetAttribute(cooperative_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
+     }
+     status = cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks,
+                                          nthrs, args, smem, stream);
+   } else {
+     if (smem > 46 * 1024) {
+       cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
+     }
+     status = cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem, stream);
+   }
+   return status == cudaSuccess;
+ }
+
+ #define INST(T, d_out)                                                                           \
+   template bool sgmv_shrink<T, d_out>(T * y, T * x, T * *w, int32_t * s_start, int32_t * s_end, \
+                                       void* tmp, uint32_t num_problems,                         \
+                                       uint32_t d_in, uint32_t layer_idx, cudaStream_t stream);
+
+ FOR_SGMV_NARROW(INST, nv_half);
+ FOR_SGMV_NARROW(INST, nv_bfloat16);
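The launcher above prefers a cooperative launch that splits `d_in` into 256-wide chunks (one block per chunk per problem) and reduces the partial results with a grid-wide sync; when that grid would exceed the number of co-resident blocks reported by the occupancy query, it falls back to a single chunk per problem and a regular launch. A small Python sketch of that decision, illustrative only:

```python
def plan_sgmv_shrink_launch(d_in: int, num_problems: int, max_grid_size: int):
    """Mirror of the chunking logic above; returns (use_cooperative, chunk_size, num_chunks).

    max_grid_size stands in for num_sm * num_blocks_per_sm from the occupancy query.
    """
    chunk_size = 256
    num_chunks = (d_in + chunk_size - 1) // chunk_size
    if num_chunks * num_problems > max_grid_size:
        # The grid could not be co-resident, which a cooperative launch requires:
        # process the whole K dimension in one chunk with a regular launch instead.
        return False, d_in, 1
    return True, chunk_size, num_chunks
```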
sgmv_flashinfer/sgmv_config.h ADDED
@@ -0,0 +1,17 @@
+ #pragma once
+ #include <cstdint>
+
+ template <typename T, uint32_t d_out>
+ bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp,
+                  uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, cudaStream_t stream);
+
+ // clang-format off
+
+ #define FOR_SGMV_NARROW(f, T) \
+     f(T, 16) \
+     f(T, 32) \
+     f(T, 64) \
+     f(T, 96) \
+     f(T, 128)
+
+ // clang-format on
sgmv_flashinfer/sgmv_flashinfer.cuh ADDED
@@ -0,0 +1,356 @@
+ #pragma once
+ #include <cooperative_groups.h>
+
+ #include "flashinfer/cp_async.cuh"
+ #include "flashinfer/mma.cuh"
+ #include "flashinfer/permuted_smem.cuh"
+ #include "flashinfer/vec_dtypes.cuh"
+
+ namespace flashinfer {
+
+ namespace sgmv {
+
+ template <bool cooperative, typename T, typename IdType, uint32_t num_warps,
+           uint32_t d_out>
+ __global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s_starts, IdType* s_ends, float* tmp,
+                             uint32_t num_problems, uint32_t d_in,
+                             uint32_t layer_idx, uint32_t chunk_size) {
+   auto block = cooperative_groups::this_thread_block();
+   auto grid = cooperative_groups::this_grid();
+   constexpr auto fill_mode = cp_async::SharedMemFillMode::kFillZero;
+   const uint32_t problem_id = blockIdx.y;
+   const uint32_t bx = blockIdx.x;
+
+   constexpr uint32_t num_stages = 2;
+   constexpr uint32_t num_k_frags = 8;
+   constexpr uint32_t num_cells_k = (num_k_frags * 16) / cell_capacity<T>();
+   constexpr uint32_t num_blocks_n = d_out / 16;
+   const uint32_t num_chunks = gridDim.x;
+   const uint32_t chunk_start = chunk_size * bx;
+   const uint32_t num_iterations =
+       (chunk_size + (num_k_frags * 16 - 1)) / (num_k_frags * 16);
+   constexpr uint32_t num_cells_n =
+       (d_out < 32 ? 32 : d_out) / cell_capacity<T>();
+   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+
+   extern __shared__ uint8_t smem[];
+
+   smem_t x_smem[2]{smem, smem + sizeof(T) * num_warps * 16 * 16 * num_k_frags};
+   smem_t w_smem[2]{smem + sizeof(T) * 2 * num_warps * 16 * 16 * num_k_frags,
+                    smem + sizeof(T) * 16 * 16 * num_k_frags *
+                               (2 * num_warps + num_blocks_n)};
+   smem_t y_smem(smem);
+
+   uint32_t x_frag[num_k_frags][4];
+   uint32_t w_frag[num_k_frags][num_blocks_n][4];
+   float y_frag[num_blocks_n][8];
+
+   const uint32_t s_start = s_starts[problem_id], s_end = s_ends[problem_id];
+   const uint32_t num_steps = (s_start < s_end) ? (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16) : 0;
+   for (uint32_t i = 0; i < num_steps; ++i) {
+     // init y_frag
+     if (bx == 0) {
+       if constexpr (num_blocks_n == 1) {
+         uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 2;
+         T* y_ptr = y + row_idx * d_out + (tx % 2) * cell_capacity<T>();
+         auto offset =
+             smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 2, tx % 2);
+         y_smem.load_128b_async<fill_mode>(offset, y_ptr, row_idx < s_end);
+       } else {
+         uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
+         T* y_ptr = y + row_idx * d_out + (tx % 4) * cell_capacity<T>();
+         auto offset =
+             smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 4, tx % 4);
+ #pragma unroll
+         for (uint32_t j = 0; j < 2; ++j) {
+ #pragma unroll
+           for (uint32_t fno = 0; fno < num_blocks_n / 2; ++fno) {
+             y_smem.load_128b_async<fill_mode>(offset, y_ptr, row_idx < s_end);
+             y_ptr += 4 * cell_capacity<T>();
+             offset += 8;
+           }
+           row_idx += 8;
+           y_ptr += 8 * d_out - 2 * num_blocks_n * cell_capacity<T>();
+           offset += 8 * num_cells_n - 4 * num_blocks_n;
+         }
+       }
+       cp_async::commit_group();
+       cp_async::wait_group<0>();
+       block.sync();
+
+       auto offset =
+           smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx % 16, tx / 16);
+ #pragma unroll
+       for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+         uint32_t tmp[4];
+         y_smem.ldmatrix_m8n8x4(offset, tmp);
+         vec_cast<float, T, 8>(y_frag[fn], (T*)tmp);
+         offset = (offset ^ 0x2) + (fn & 0x1) * 8;
+       }
+     } else {
+ #pragma unroll
+       for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+ #pragma unroll
+         for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
+           y_frag[fn][reg_id] = 0.f;
+         }
+       }
+     }
+
+     // preload x_smem, w_smem
+ #pragma unroll
+     for (uint32_t iter = 0; iter < num_stages; ++iter) {
+       uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
+       T* x_ptr = x + row_idx * d_in + chunk_start +
+                  (2 * num_k_frags * iter + tx % 4) * cell_capacity<T>();
+       T* x_ptr_max = x + row_idx * d_in + min(d_in, chunk_start + chunk_size);
+       auto offset =
+           smem_t::get_permuted_offset<num_cells_k>(ty * 16 + tx / 4, tx % 4);
+       // pre-load x_smem, w_smem
+ #pragma unroll
+       for (uint32_t j = 0; j < 2; ++j) {
+ #pragma unroll
+         for (uint32_t fko = 0; fko < num_k_frags / 2; ++fko) {
+           x_smem[iter].load_128b_async<fill_mode>(
+               offset, x_ptr, row_idx < s_end && x_ptr < x_ptr_max);
+           x_ptr += 4 * cell_capacity<T>();
+           offset += 8;
+         }
+         row_idx += 8;
+         x_ptr += 8 * d_in - 2 * cell_capacity<T>() * num_k_frags;
+         x_ptr_max += 8 * d_in;
+         offset += 8 * num_cells_k - 4 * num_k_frags;
+       }
+       row_idx -= 8;
+
+       static_assert(num_k_frags % (num_warps * 2) == 0);
+       constexpr uint32_t num_fko_iters_per_warp = num_k_frags / (num_warps * 2);
+ #pragma unroll
+       for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+         T* w_ptr = w[problem_id] + layer_idx * d_in * d_out +
+                    (fn * 16 + tx / 4) * d_in + chunk_start +
+                    (2 * num_k_frags * iter + ty * num_fko_iters_per_warp * 4 +
+                     tx % 4) *
+                        cell_capacity<T>();
+         T* w_ptr_max =
+             w[problem_id] + layer_idx * d_in * d_out +
+             min((fn * 16 + tx / 4 + 1) * d_in,
+                 (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size);
+         auto offset = smem_t::get_permuted_offset<num_cells_k>(
+             fn * 16 + tx / 4, ty * num_fko_iters_per_warp * 4 + tx % 4);
+ #pragma unroll
+         for (uint32_t j = 0; j < 2; ++j) {
+ #pragma unroll
+           for (uint32_t fko = 0; fko < num_fko_iters_per_warp; ++fko) {
+             w_smem[iter].load_128b_async<fill_mode>(offset, w_ptr,
+                                                     w_ptr < w_ptr_max);
+             w_ptr += 4 * cell_capacity<T>();
+             offset += 8;
+           }
+           w_ptr += 8 * d_in - 4 * cell_capacity<T>() * num_fko_iters_per_warp;
+           w_ptr_max += 8 * d_in;
+           offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp;
+         }
+       }
+       cp_async::commit_group();
+     }
+
+ #pragma unroll 1
+     for (uint32_t iter = 0; iter < num_iterations; ++iter) {
+       const uint32_t stage_idx = iter % 2;
+       cp_async::wait_group<1>();
+       block.sync();
+
+       auto offset =
+           smem_t::get_permuted_offset<num_cells_k>(ty * 16 + tx % 16, tx / 16);
+ #pragma unroll
+       for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
+         x_smem[stage_idx].ldmatrix_m8n8x4(offset, x_frag[fk]);
+         offset = (offset ^ 0x2) + (fk & 0x1) * 8;
+       }
+
+ #pragma unroll
+       for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+         auto offset = smem_t::get_permuted_offset<num_cells_k>(
+             fn * 16 + 8 * (tx / 16) + tx % 8, (tx % 16) / 8);
+ #pragma unroll
+         for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
+           w_smem[stage_idx].ldmatrix_m8n8x4(offset, w_frag[fk][fn]);
+           offset = (offset ^ 0x2) + (fk & 0x1) * 8;
+         }
+         offset += 16 * num_cells_k - 4 * num_k_frags;
+       }
+
+       // compute y_frag
+ #pragma unroll
+       for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
+ #pragma unroll
+         for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+           mma::mma_sync_m16n16k16_row_col_f16f16f32<T>(y_frag[fn], x_frag[fk],
+                                                        w_frag[fk][fn]);
+         }
+       }
+       block.sync();
+
+       // load next stage
+       if (iter + num_stages < num_iterations) {
+         uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
+         T* x_ptr = x + row_idx * d_in + chunk_start +
+                    (2 * num_k_frags * (iter + num_stages) + tx % 4) *
+                        cell_capacity<T>();
+         T* x_ptr_max = x + row_idx * d_in + min(d_in, chunk_start + chunk_size);
+         auto offset =
+             smem_t::get_permuted_offset<num_cells_k>(ty * 16 + tx / 4, tx % 4);
+         // pre-load x_smem, w_smem
+ #pragma unroll
+         for (uint32_t j = 0; j < 2; ++j) {
+ #pragma unroll
+           for (uint32_t fko = 0; fko < num_k_frags / 2; ++fko) {
+             x_smem[stage_idx].load_128b_async<fill_mode>(
+                 offset, x_ptr, row_idx < s_end && x_ptr < x_ptr_max);
+             x_ptr += 4 * cell_capacity<T>();
+             offset += 8;
+           }
+           row_idx += 8;
+           x_ptr += 8 * d_in - 2 * cell_capacity<T>() * num_k_frags;
+           x_ptr_max += 8 * d_in;
+           offset += 8 * num_cells_k - 4 * num_k_frags;
+         }
+         row_idx -= 8;
+
+         constexpr uint32_t num_fko_iters_per_warp =
+             num_k_frags / (num_warps * 2);
+ #pragma unroll
+         for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+           T* w_ptr = w[problem_id] + layer_idx * d_in * d_out +
+                      (fn * 16 + tx / 4) * d_in + chunk_start +
+                      (2 * num_k_frags * (iter + num_stages) +
+                       ty * num_fko_iters_per_warp * 4 + tx % 4) *
+                          cell_capacity<T>();
+           T* w_ptr_max =
+               w[problem_id] + layer_idx * d_in * d_out +
+               min((fn * 16 + tx / 4 + 1) * d_in,
+                   (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size);
+           auto offset = smem_t::get_permuted_offset<num_cells_k>(
+               fn * 16 + tx / 4, ty * num_fko_iters_per_warp * 4 + tx % 4);
+ #pragma unroll
+           for (uint32_t j = 0; j < 2; ++j) {
+ #pragma unroll
+             for (uint32_t fko = 0; fko < num_fko_iters_per_warp; ++fko) {
+               w_smem[stage_idx].load_128b_async<fill_mode>(offset, w_ptr,
+                                                            w_ptr < w_ptr_max);
+               w_ptr += 4 * cell_capacity<T>();
+               offset += 8;
+             }
+             w_ptr += 8 * d_in - 4 * cell_capacity<T>() * num_fko_iters_per_warp;
+             w_ptr_max += 8 * d_in;
+             offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp;
+           }
+         }
+       }
+       cp_async::commit_group();
+     }
+     cp_async::wait_group<0>();
+     block.sync();
+
+     if constexpr (cooperative) {
+ #pragma unroll
+       for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+         vec_t<float, 8>::memcpy(
+             tmp + (fn * grid.size() +
+                    (problem_id * num_chunks + bx) * block.num_threads() +
+                    block.thread_rank()) *
+                       8,
+             y_frag[fn]);
+       }
+       grid.sync();
+
+ #pragma unroll
+       for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+ #pragma unroll
+         for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
+           y_frag[fn][reg_id] = 0.f;
+         }
+         for (uint32_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
+           vec_t<float, 8> y_other;
+           y_other.load(tmp + (fn * grid.size() +
+                               (problem_id * num_chunks + chunk_idx) *
+                                   block.num_threads() +
+                               block.thread_rank()) *
+                                  8);
+ #pragma unroll
+           for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
+             y_frag[fn][reg_id] += y_other[reg_id];
+           }
+         }
+       }
+     }
+
+     if (bx == 0) {
+       // store y_frag
+       auto offset =
+           smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 4, 0);
+ #pragma unroll
+       for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
+         vec_cast<T, float, 2>((T*)(y_smem.base + offset) + (tx % 4) * 2,
+                               &y_frag[fn][0]);
+         vec_cast<T, float, 2>(
+             (T*)(y_smem.base + offset + 8 * num_cells_n) + (tx % 4) * 2,
+             &y_frag[fn][2]);
+         vec_cast<T, float, 2>((T*)(y_smem.base + (offset ^ 0x1)) + (tx % 4) * 2,
+                               &y_frag[fn][4]);
+         vec_cast<T, float, 2>(
+             (T*)(y_smem.base + (offset ^ 0x1) + 8 * num_cells_n) + (tx % 4) * 2,
+             &y_frag[fn][6]);
+         offset = (offset ^ 0x2) + (fn & 0x1) * 8;
+       }
+
+       // store y
+       if constexpr (num_blocks_n == 1) {
+         uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 2;
+         T* y_ptr = y + row_idx * d_out + (tx % 2) * cell_capacity<T>();
+         auto offset =
+             smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 2, tx % 2);
+         if (row_idx < s_end) {
+           y_smem.store_128b(offset, y_ptr);
+         }
+       } else {
+         uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
+         T* y_ptr = y + row_idx * d_out + (tx % 4) * cell_capacity<T>();
+         auto offset =
+             smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 4, tx % 4);
+ #pragma unroll
+         for (uint32_t j = 0; j < 2; ++j) {
+ #pragma unroll
+           for (uint32_t fno = 0; fno < num_blocks_n / 2; ++fno) {
+             if (row_idx < s_end) {
+               y_smem.store_128b(offset, y_ptr);
+             }
+             y_ptr += 4 * cell_capacity<T>();
+             offset += 8;
+           }
+           row_idx += 8;
+           y_ptr += 8 * d_out - 2 * num_blocks_n * cell_capacity<T>();
+           offset += 8 * num_cells_n - 4 * num_blocks_n;
+         }
+       }
+     }
+   }
+
+   // handle the case where one of the segments needs more steps than this one
+   // to avoid deadlock
+   if constexpr (cooperative) {
+     uint32_t max_segment_size = 0;
+     for (uint32_t i = 0; i < num_problems; ++i) {
+       max_segment_size = max(max_segment_size, s_ends[i] - s_starts[i]);
+     }
+
+     const uint32_t max_steps = (max_segment_size + (num_warps * 16 - 1)) / (num_warps * 16);
+     for (uint32_t i = 0; i < max_steps - num_steps; ++i) {
+       grid.sync();
+     }
+   }
+ }
+
+ }  // namespace sgmv
+ }  // namespace flashinfer
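In the cooperative path, each block computes the partial product for its K chunk, spills its `y` fragments to the float32 `tmp` buffer, grid-syncs, and then sums the partials of every chunk belonging to its problem; the trailing loop of extra `grid.sync()` calls keeps short segments participating in as many barriers as the longest segment, which is what the deadlock comment above refers to. A rough NumPy sketch of the split-K reduction semantics only (layouts simplified; the kernel itself stores W as `[d_out, d_in]` and accumulates in fp32):

```python
import numpy as np

def split_k_reference(x_seg, w, chunk_size=256):
    """Chunked (split-K) matmul: per-chunk partials are summed after a barrier.

    x_seg: [rows, d_in] segment of the input
    w:     [d_in, d_out] weight slice for this problem (simplified layout)
    """
    d_in = x_seg.shape[1]
    partials = []
    for start in range(0, d_in, chunk_size):
        stop = min(start + chunk_size, d_in)
        partials.append(
            x_seg[:, start:stop].astype(np.float32) @ w[start:stop].astype(np.float32)
        )
    return sum(partials)  # numerically equivalent to x_seg @ w accumulated in fp32
```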
tests/test_sgmv.py ADDED
@@ -0,0 +1,125 @@
+ from typing import List, Tuple
+
+ import pytest
+ import torch
+
+ from punica_sgmv import (
+     get_tmp_tensors,
+     lora_a_sgmv_cutlass,
+     lora_b_sgmv_cutlass,
+     pad_rank,
+     use_cutlass_shrink,
+ )
+
+
+ def lora_ref_impl(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa: List[torch.Tensor],
+     wb: List[torch.Tensor],
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     for i in range(len(wa)):
+         if s_end[i] - s_start[i] <= 0:
+             continue
+
+         xi = x[s_start[i]:s_end[i]]
+         wai = wa[i][layer_idx, :, :]
+         wbi = wb[i][layer_idx, :, :]
+
+         if not use_cutlass_shrink(lora_rank):
+             wai = wai.t()
+
+         yi = y[s_start[i]:s_end[i]]
+         tmp = xi @ wai
+         y[s_start[i]:s_end[i]] = yi + tmp @ wbi
+
+
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.parametrize("segments", [
+     ([0, 2], [1, 3]),
+     ([0, -1], [1, -1]),
+ ])
+ @pytest.mark.parametrize("lora_rank", [8, 16, 32, 64, 128])
+ def test_add_lora_sgmv(lora_rank: int, segments: Tuple[List[int], List[int]]):
+     torch.manual_seed(42)
+
+     B = 3
+     H = 1024
+     r = lora_rank
+     nlayers = 2
+
+     device = torch.device("cuda:0")
+
+     y = torch.zeros((B, H), dtype=torch.float16, device=device)
+     x = torch.randn((B, H), dtype=torch.float16, device=device)
+     wa = torch.randn(nlayers, r, H, dtype=torch.float16, device=device)
+     if use_cutlass_shrink(r):
+         # cutlass uses (H, r) layout
+         wa = wa.transpose(1, 2).contiguous()
+
+     # TODO(travis): transpose (r, H) -> (H, r) when not using cutlass
+     wb = torch.randn(nlayers, r, H, dtype=torch.float16, device=device)
+
+     s1, s2 = segments
+     s_start = torch.tensor(s1, dtype=torch.int32, device=device)
+     s_end = torch.tensor(s2, dtype=torch.int32, device=device)
+
+     wa_list = [wa if end - start > 0 else None for start, end in zip(s1, s2)]
+     wb_list = [wb if end - start > 0 else None for start, end in zip(s1, s2)]
+
+     wa_ptr = torch.tensor([wa.data_ptr() if wa is not None else 0 for wa in wa_list], dtype=torch.int64, device=device)
+     wb_ptr = torch.tensor([wb.data_ptr() if wb is not None else 0 for wb in wb_list], dtype=torch.int64, device=device)
+
+     layer_idx = 0
+
+     y_ref = y.clone()
+     lora_ref_impl(y_ref, x, wa_list, wb_list, s_start, s_end, layer_idx, r)
+
+     tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), r, x.device)
+     y_ours = torch.zeros((B, H), dtype=torch.float16, device=device)
+
+     v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r)
+     lora_b_sgmv_cutlass(y_ours, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)
+
+     assert torch.allclose(y_ref, y_ours, rtol=1e-2, atol=1e-3)
+
+     # graph trace
+     tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), r, x.device)
+     y_ours_graph = torch.zeros((B, H), dtype=torch.float16, device=device)
+
+     torch.cuda.synchronize(device)
+     graph = torch.cuda.CUDAGraph()
+     with torch.cuda.graph(graph, pool=None):
+         v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r)
+         lora_b_sgmv_cutlass(y_ours_graph, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)
+
+     torch.cuda.synchronize(device)
+     graph.replay()
+
+     assert torch.allclose(y_ours, y_ours_graph, rtol=1e-2, atol=1e-3)
+
+
+ @pytest.mark.parametrize("world_size", [1, 2, 4, 8])
+ @pytest.mark.parametrize("lora_rank", [8, 16, 32, 64, 128])
+ def test_pad_rank(lora_rank: int, world_size: int):
+     bs = 8
+     h = 1024
+     x = torch.randn((bs, h), dtype=torch.float16)
+
+     lora_a = torch.randn((h, lora_rank), dtype=torch.float16)
+     lora_b = torch.randn((lora_rank, h), dtype=torch.float16)
+
+     lora_a_padded = pad_rank(lora_a, dim=1, world_size=world_size)
+     lora_b_padded = pad_rank(lora_b, dim=0, world_size=world_size)
+
+     assert lora_a_padded.size(1) == lora_b_padded.size(0)
+     assert lora_a_padded.size(1) >= lora_a.size(1)
+     assert lora_b_padded.size(0) >= lora_b.size(0)
+
+     expected = x @ lora_a @ lora_b
+     actual = x @ lora_a_padded @ lora_b_padded
+     assert torch.allclose(expected, actual)
torch-ext/punica_sgmv/__init__.py ADDED
@@ -0,0 +1,172 @@
+ from typing import Optional, Tuple
+ from functools import lru_cache
+
+ import torch
+ import torch.nn.functional as F
+
+ from ._ops import ops
+
+ MIN_SGMV_RANK = 8
+ MIN_RANK_CUSTOM = 16
+ MAX_RANK_CUSTOM = 128
+ SGMV_BLOCK_SIZE = 16
+ BGMV_MAX_RANK = 128
+
+
+ def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
+     if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
+         return t.transpose(0, 1)
+     return t
+
+
+ def add_lora_sgmv_cutlass(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.Tensor,
+     s_end: torch.Tensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     """
+     Semantics:
+         y[s_start[i]:s_end[i]] += x[s_start[i]:s_end[i]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])
+
+     Args:
+         y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+         x: Shape: `[B, H1]`. Input vectors.
+         wa_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the LoRA A weight matrices.\
+             Weight matrix shape: `[num_layers, R, H1]`.
+         wb_ptr: Shape: `[S]`. DType: torch.int64. Pointers to the LoRA B weight matrices.\
+             Weight matrix shape: `[num_layers, R, H2]`.
+         s_start: Shape: `[S]`, DType: torch.int32. Start index of each segment in the batch.
+         s_end: Shape: `[S]`, DType: torch.int32. End index (exclusive) of each segment in the batch.
+         layer_idx: Layer index of the weight matrices.
+     """
+     if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
+         # The custom SGMV shrink kernel only supports ranks 16, 32, 64, 96 and 128;
+         # fall back to the CUTLASS-only path for other ranks.
+         _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank)
+         return
+
+     tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
+     tmp2_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)
+
+
+ def _add_lora_sgmv_cutlass_legacy(
+     y: torch.Tensor,
+     x: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ):
+     tmp_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
+     tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+
+ def lora_a_sgmv_cutlass(
+     x: torch.Tensor,
+     tmp: torch.Tensor,
+     wa_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+     lora_rank: int,
+ ) -> torch.Tensor:
+     v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
+     if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
+         ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     else:
+         ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
+     return v
+
+
+ def lora_b_sgmv_cutlass(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     tmp: torch.Tensor,
+     wb_ptr: torch.Tensor,
+     s_start: torch.IntTensor,
+     s_end: torch.IntTensor,
+     layer_idx: int,
+ ):
+     ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)
+
+
+ def add_lora_a_bgmv(
+     v: torch.Tensor,
+     x: torch.Tensor,
+     wa_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)
+
+
+ def add_lora_b_bgmv(
+     y: torch.Tensor,
+     v: torch.Tensor,
+     wb_T_all: torch.Tensor,
+     indicies: torch.LongTensor,
+     layer_idx: int,
+ ):
+     ops.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)
+
+
+ def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
+     """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
+     # Tensor parallelism divides the effective rank by world_size,
+     # so scale the minimum rank up to offset that effect.
+     min_rank = MIN_SGMV_RANK * world_size
+     return pad_to_min_rank(t, dim, min_rank)
+
+
+ def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
+     # If we're at or below the min rank, pad up to the min rank;
+     # otherwise, pad to the nearest multiple of the block size.
+     current_rank = t.size(dim)
+     target_rank = (
+         min_rank
+         if current_rank <= min_rank
+         else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
+     )
+     if current_rank == target_rank:
+         return t
+
+     pad_size = target_rank - current_rank
+
+     # See the (somewhat complicated) pad syntax here:
+     # https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
+     pad = [0, 0] * t.dim()
+     pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
+     pad = tuple(pad)
+
+     return F.pad(t, pad, mode="constant", value=0.0)
+
+
+ def use_cutlass_shrink(lora_rank: int) -> bool:
+     return lora_rank < MIN_RANK_CUSTOM
+
+
+ @lru_cache(maxsize=1)
+ def get_tmp_tensor(device: torch.device) -> torch.Tensor:
+     return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)
+
+
+ @lru_cache(maxsize=32)
+ def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
+     tmp_size = ops.sgmv_cutlass_tmp_size(size)
+     return torch.empty((tmp_size,), dtype=torch.uint8, device=device)
+
+
+ def get_tmp_expand_size(size: int) -> int:
+     return ops.sgmv_cutlass_tmp_size(size)
+
+
+ def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+     if use_cutlass_shrink(lora_rank):
+         tmp = get_tmp_tensor_for_size(nsegments, device)
+         return tmp, tmp
+     else:
+         tmp_shrink = get_tmp_tensor(device)
+         tmp_expand = get_tmp_tensor_for_size(nsegments, device)
+         return tmp_shrink, tmp_expand
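Putting the pieces together, a typical call site allocates the temporary buffers once, runs the shrink (`x @ A`) into `v`, and then the expand (`v @ B`) into `y`. A minimal usage sketch mirroring `tests/test_sgmv.py`; it assumes a CUDA device and the built `punica_sgmv` extension:

```python
import torch
from punica_sgmv import (
    get_tmp_tensors,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    use_cutlass_shrink,
)

device = torch.device("cuda:0")
B, H, r, layer_idx = 3, 1024, 16, 0

x = torch.randn((B, H), dtype=torch.float16, device=device)
y = torch.zeros((B, H), dtype=torch.float16, device=device)

wa = torch.randn(2, r, H, dtype=torch.float16, device=device)  # [num_layers, r, H]
if use_cutlass_shrink(r):
    wa = wa.transpose(1, 2).contiguous()  # ranks < 16 use the cutlass (H, r) layout
wb = torch.randn(2, r, H, dtype=torch.float16, device=device)  # [num_layers, r, H]

# One adapter applied to rows [0, 2) of the batch; pointer tensors index per-segment weights.
wa_ptr = torch.tensor([wa.data_ptr()], dtype=torch.int64, device=device)
wb_ptr = torch.tensor([wb.data_ptr()], dtype=torch.int64, device=device)
s_start = torch.tensor([0], dtype=torch.int32, device=device)
s_end = torch.tensor([2], dtype=torch.int32, device=device)

tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), r, device)
v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r)
lora_b_sgmv_cutlass(y, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)
```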
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,23 @@
+ #include <torch/library.h>
+
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+   ops.def("sgmv_shrink(Tensor! y, Tensor x, Tensor w_ptr, Tensor s_start, "
+           "Tensor s_end, Tensor! tmp, int layer_idx) -> ()");
+   ops.impl("sgmv_shrink", torch::kCUDA, &dispatch_sgmv_shrink);
+
+   ops.def("sgmv_cutlass(Tensor! y, Tensor x, Tensor w_ptr, Tensor s_start, "
+           "Tensor s_end, Tensor! tmp, int layer_idx) -> ()");
+   ops.impl("sgmv_cutlass", torch::kCUDA, &dispatch_sgmv_cutlass);
+
+   ops.def("sgmv_cutlass_tmp_size(int num_problems) -> int");
+   ops.impl("sgmv_cutlass_tmp_size", &sgmv_tmp_size);
+
+   ops.def("dispatch_bgmv(Tensor! y, Tensor x, Tensor w_ptr, Tensor indices, "
+           "int layer_idx, float scale) -> ()");
+   ops.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,16 @@
+ #pragma once
+
+ #include <torch/torch.h>
+
+ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
+                    torch::Tensor indicies, int64_t layer_idx, double scale);
+
+ void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
+                            torch::Tensor s_start, torch::Tensor s_end,
+                            torch::Tensor tmp, int64_t layer_idx);
+
+ void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
+                           torch::Tensor s_start, torch::Tensor s_end,
+                           torch::Tensor tmp, int64_t layer_idx);
+
+ int64_t sgmv_tmp_size(int64_t num_problems);