Commit 5a84343 · 0 parent(s)

Add Punica sgmv kernels

Files changed:
- README.md +11 -0
- bgmv/bgmv_all.cu +5 -0
- bgmv/bgmv_config.h +88 -0
- bgmv/bgmv_impl.cuh +296 -0
- build.toml +53 -0
- flake.lock +117 -0
- flake.nix +17 -0
- flashinfer/cp_async.cuh +187 -0
- flashinfer/mma.cuh +410 -0
- flashinfer/permuted_smem.cuh +95 -0
- flashinfer/vec_dtypes.cuh +1262 -0
- punica_kernels/punica_ops.cc +220 -0
- sgmv/sgmv.h +10 -0
- sgmv/sgmv_cutlass.cu +14 -0
- sgmv/sgmv_cutlass.cuh +180 -0
- sgmv_flashinfer/sgmv_all.cu +73 -0
- sgmv_flashinfer/sgmv_config.h +17 -0
- sgmv_flashinfer/sgmv_flashinfer.cuh +356 -0
- tests/test_sgmv.py +125 -0
- torch-ext/punica_sgmv/__init__.py +172 -0
- torch-ext/torch_binding.cpp +23 -0
- torch-ext/torch_binding.h +16 -0
README.md ADDED
@@ -0,0 +1,11 @@
---
license: apache-2.0
tags:
- kernel
---

## Punica sgmv

[Punica](https://github.com/punica-ai/punica) sgmv kernels with modifications
from [Lorax](https://github.com/predibase/lorax).
bgmv/bgmv_all.cu ADDED
@@ -0,0 +1,5 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half)
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16)
bgmv/bgmv_config.h ADDED
@@ -0,0 +1,88 @@
#pragma once

template <int feat_in, int feat_out, typename T>
void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X,
                 T **__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size,
                 int64_t layer_idx, float scale);

// clang-format off

#define FOR_BGMV_WIDE(f, T, narrow) \
    f(T, narrow, 256) \
    f(T, narrow, 512) \
    f(T, narrow, 640) \
    f(T, narrow, 768) \
    f(T, narrow, 1024) \
    f(T, narrow, 1152) \
    f(T, narrow, 1280) \
    f(T, narrow, 1536) \
    f(T, narrow, 1728) \
    f(T, narrow, 1792) \
    f(T, narrow, 2048) \
    f(T, narrow, 2304) \
    f(T, narrow, 2560) \
    f(T, narrow, 2752) \
    f(T, narrow, 2816) \
    f(T, narrow, 3072) \
    f(T, narrow, 3456) \
    f(T, narrow, 3584) \
    f(T, narrow, 4096) \
    f(T, narrow, 4480) \
    f(T, narrow, 4608) \
    f(T, narrow, 5120) \
    f(T, narrow, 5504) \
    f(T, narrow, 5632) \
    f(T, narrow, 6144) \
    f(T, narrow, 6848) \
    f(T, narrow, 6912) \
    f(T, narrow, 7168) \
    f(T, narrow, 7680) \
    f(T, narrow, 8192) \
    f(T, narrow, 8960) \
    f(T, narrow, 9216) \
    f(T, narrow, 9472) \
    f(T, narrow, 10240) \
    f(T, narrow, 11008) \
    f(T, narrow, 12288) \
    f(T, narrow, 13696) \
    f(T, narrow, 13824) \
    f(T, narrow, 14336) \
    f(T, narrow, 15360) \
    f(T, narrow, 16384) \
    f(T, narrow, 17920) \
    f(T, narrow, 18944) \
    f(T, narrow, 20480) \
    f(T, narrow, 22016) \
    f(T, narrow, 24576) \
    f(T, narrow, 27392) \
    f(T, narrow, 27648) \
    f(T, narrow, 28672) \
    f(T, narrow, 32000) \
    f(T, narrow, 32256) \
    f(T, narrow, 32512) \
    f(T, narrow, 32768) \
    f(T, narrow, 33024) \
    f(T, narrow, 35840) \
    f(T, narrow, 36864) \
    f(T, narrow, 43264) \
    f(T, narrow, 49152) \
    f(T, narrow, 64000) \
    f(T, narrow, 64256) \
    f(T, narrow, 64512) \
    f(T, narrow, 102400) \
    f(T, narrow, 102656) \
    f(T, narrow, 102912) \
    f(T, narrow, 128000) \
    f(T, narrow, 128256) \
    f(T, narrow, 128512) \

#define FOR_BGMV_WIDE_NARROW(f, T) \
    FOR_BGMV_WIDE(f, T, 8) \
    FOR_BGMV_WIDE(f, T, 16) \
    FOR_BGMV_WIDE(f, T, 32) \
    FOR_BGMV_WIDE(f, T, 64) \
    FOR_BGMV_WIDE(f, T, 128)

// clang-format on
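As a quick illustration of how these macros drive instantiation: `bgmv/bgmv_all.cu` applies `FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half)`, and `INST_BGMV_TWOSIDE` / `INST_BGMV` (defined at the end of `bgmv_impl.cuh` below) emit explicit template instantiations for each (narrow, wide) pair, in both directions. A sketch of the expansion for the single pair narrow = 8, wide = 256 with `T = nv_half`:

```cuda
// Illustrative expansion of INST_BGMV_TWOSIDE(nv_half, 8, 256); the real
// expansion repeats this for every wide value in FOR_BGMV_WIDE and every
// narrow value in FOR_BGMV_WIDE_NARROW.
template void bgmv_kernel<8, 256>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    nv_half **__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t layer_idx, float scale);
template void bgmv_kernel<256, 8>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    nv_half **__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t layer_idx, float scale);
```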
bgmv/bgmv_impl.cuh ADDED
@@ -0,0 +1,296 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "flashinfer/vec_dtypes.cuh"

namespace cg = cooperative_groups;

// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename T>
__global__ void
bgmv_shrink_kernel(T* __restrict__ Y, const T* __restrict__ X,
                   T** __restrict__ W,
                   const int64_t* __restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx];
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ T W_shared[num_pipeline_stages * tile_size];
  __shared__ T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  const T* W_ptr = W[idx];

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W_ptr + (layer_idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  flashinfer::vec_t<T, vec_size> x_vec;
  flashinfer::vec_t<T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W_ptr + (layer_idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<T>(y);
  }
}

// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename T>
__global__ void
bgmv_expand_kernel(T* __restrict__ Y, const T* __restrict__ X,
                   T** __restrict__ W,
                   const int64_t* __restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx];

  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  const T* W_ptr = W[idx];

  // load X;
  flashinfer::vec_t<T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  flashinfer::vec_t<T, vec_size> w_vec;
  w_vec.load(W_ptr + (layer_idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<T>(sum);
  }
}

template <int feat_in, int feat_out, typename T>
void bgmv_kernel(T* __restrict__ Y, const T* __restrict__ X,
                 T** __restrict__ W,
                 const int64_t* __restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, layer_idx,
                                        scale);
    }
  } else {
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(T),
                         vec_size * sizeof(T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(T) / 2,
                         vec_size * sizeof(T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(T) / 2,
                         vec_size * sizeof(T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, layer_idx,
                                        scale);
    }
  }
}

#define INST_BGMV(feat_in, feat_out, T)                            \
  template void bgmv_kernel<feat_in, feat_out>(                    \
      T* __restrict__ Y, const T* __restrict__ X,                  \
      T** __restrict__ W, const int64_t* __restrict__ indicies,    \
      int64_t y_offset, int64_t full_y_size, int64_t batch_size,   \
      int64_t layer_idx, float scale);

#define INST_BGMV_TWOSIDE(T, narrow, wide)  \
  INST_BGMV(narrow, wide, T)                \
  INST_BGMV(wide, narrow, T)
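For orientation, both kernels above implement the same math: for each batch element `b`, the value `indicies[b]` gathers one adapter's weight stack, and the kernel accumulates `scale * W[indicies[b]][layer_idx] @ X[b]` into a slice of `Y` starting at `y_offset` (rows with a negative index are skipped). A minimal CPU reference sketch of that semantics, using plain `float` buffers and the layouts visible in the kernels (illustrative only, not part of the commit):

```cuda
// CPU reference for bgmv_kernel's math. W[i] points to a
// [num_layers, feat_out, feat_in] row-major weight stack, X is
// [batch, feat_in], Y is [batch, full_y_size].
void bgmv_reference(float* Y, const float* X, float* const* W,
                    const int64_t* indicies, int64_t y_offset,
                    int64_t full_y_size, int64_t batch_size,
                    int64_t layer_idx, float scale,
                    int64_t feat_in, int64_t feat_out) {
  for (int64_t b = 0; b < batch_size; ++b) {
    int64_t idx = indicies[b];
    if (idx < 0) continue;  // negative index: skip this batch row
    const float* Wb = W[idx] + layer_idx * feat_out * feat_in;
    for (int64_t j = 0; j < feat_out; ++j) {
      float acc = 0.f;
      for (int64_t k = 0; k < feat_in; ++k) {
        acc += Wb[j * feat_in + k] * X[b * feat_in + k];
      }
      Y[b * full_y_size + y_offset + j] += scale * acc;
    }
  }
}
```

The shrink kernel handles the `feat_in >= feat_out` case (one output element per block, warp-reduced over `feat_in`), while the expand kernel handles `feat_in < feat_out` (many output elements per block, each reduced by a small thread tile).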
build.toml ADDED
@@ -0,0 +1,53 @@
[general]
name = "punica_sgmv"

[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h"
]

[kernel.sgmv]
language = "cuda"
src = [
  "sgmv/sgmv_cutlass.cu",
  "sgmv/sgmv_cutlass.cuh",
]
depends = [ "cutlass_3_8", "torch" ]

[kernel.sgmv_flashinfer]
language = "cuda"
cuda-capabilities = [ "8.0", "8.6", "8.7", "8.9", "9.0", "10.0", "10.1", "12.0" ]
src = [
  "flashinfer/cp_async.cuh",
  "flashinfer/mma.cuh",
  "flashinfer/permuted_smem.cuh",
  "flashinfer/vec_dtypes.cuh",
  "sgmv_flashinfer/sgmv_all.cu",
  "sgmv_flashinfer/sgmv_config.h",
  "sgmv_flashinfer/sgmv_flashinfer.cuh"
]
include = [ "." ]
depends = [ "torch" ]

[kernel.bgmv]
language = "cuda"
src = [
  "bgmv/bgmv_all.cu",
  "bgmv/bgmv_impl.cuh",
  "bgmv/bgmv_config.h",
  "flashinfer/vec_dtypes.cuh"
]
include = [ "." ]
depends = [ "torch" ]

[kernel.punica_kernels]
language = "cuda"
src = [
  "bgmv/bgmv_config.h",
  "punica_kernels/punica_ops.cc",
  "sgmv/sgmv.h",
  "sgmv_flashinfer/sgmv_config.h"
]
include = [ "." ]
depends = [ "torch" ]
flake.lock ADDED
@@ -0,0 +1,117 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1733328505,
        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "kernel-builder": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils",
        "nixpkgs": "nixpkgs",
        "rocm-nix": "rocm-nix"
      },
      "locked": {
        "lastModified": 1747143871,
        "narHash": "sha256-gXYPmA7wBqcTy1+39Z/UAIZ5mCSl9W5IoAvDQhIezec=",
        "owner": "huggingface",
        "repo": "kernel-builder",
        "rev": "a78a83cfb31373e0782921999e1917b7f91af7d3",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "kernel-builder",
        "type": "github"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1746711195,
        "narHash": "sha256-bSpM2ySq12PBOVN7jZdzXsc99iRoYOyolh5wz43+CjQ=",
        "owner": "danieldk",
        "repo": "nixpkgs",
        "rev": "6b7a66b06ccb09ac95872ac6ddf952e0660672ab",
        "type": "github"
      },
      "original": {
        "owner": "danieldk",
        "ref": "kernel-builder-cuda-12.9.0",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "rocm-nix": {
      "inputs": {
        "nixpkgs": [
          "kernel-builder",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1745310663,
        "narHash": "sha256-1U3PzCO/jt7HUlEgLOY3RpxadKwTo6GSvb2j4m0UFw0=",
        "owner": "huggingface",
        "repo": "rocm-nix",
        "rev": "e08373a0efa1c297b0c57af070e0a311df47481f",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "rocm-nix",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "kernel-builder": "kernel-builder"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
flake.nix ADDED
@@ -0,0 +1,17 @@
{
  description = "Flake for Punica SGMV kernel";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
flashinfer/cp_async.cuh ADDED
@@ -0,0 +1,187 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_CP_ASYNC_CUH_
#define FLASHINFER_CP_ASYNC_CUH_

#include <cuda_runtime.h>

namespace flashinfer {

namespace cp_async {

enum class SharedMemFillMode {
  kFillZero,  // Fill zero to shared memory when predicate is false
  kNoFill     // Do not fill zero to shared memory when predicate is false
};

enum class PrefetchMode {
  kNoPrefetch,  // Do not fetch additional data from global memory to L2
  kPrefetch     // Fetch additional data from global memory to L2
};

#if (__CUDACC_VER_MAJOR__ >= 11)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
#define FLASHINFER_CP_ASYNC_ENABLED
#endif
#endif

/*!
 * \brief Wrapper of PTX cp.async.commit_group instruction, commit all prior uncommitted
 * cp.async instructions to a group
 */
__device__ __forceinline__ void commit_group() {
#ifdef FLASHINFER_CP_ASYNC_ENABLED
  asm volatile("cp.async.commit_group;\n" ::);
#endif
}

/*!
 * \brief Wrapper of PTX cp.async.wait_group instruction
 * \tparam n Wait till most recent n groups are committed
 */
template <size_t n>
__device__ __forceinline__ void wait_group() {
#ifdef FLASHINFER_CP_ASYNC_ENABLED
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
#endif
}

/*!
 * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from
 * global memory to shared memory
 * \tparam prefetch_mode Whether to fetch additional data from global memory to L2
 * \tparam T Data type
 * \param smem_ptr Pointer to shared memory
 * \param gmem_ptr Pointer to global memory
 */
template <PrefetchMode prefetch_mode, typename T>
__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
#ifdef FLASHINFER_CP_ASYNC_ENABLED
  uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
    asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
                 "l"(gmem_ptr), "n"(16), "r"(16));
  } else {
    asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
                 "l"(gmem_ptr), "n"(16), "r"(16));
  }
#else
  *((uint4*)smem_ptr) = *((uint4*)gmem_ptr);
#endif
}

/*!
 * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from
 * global memory to shared memory with predicate.
 * \tparam prefetch_mode Whether to fetch additional data from global memory to L2
 * \tparam fill_mode Whether to fill zero to shared memory when predicate is false
 * \tparam T Data type
 * \param smem_ptr Pointer to shared memory
 * \param gmem_ptr Pointer to global memory
 * \param predicate Predicate value
 * \note fill zero is slower than not fill zero
 */
template <PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) {
#ifdef FLASHINFER_CP_ASYNC_ENABLED
  uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
    int src_in_bytes = predicate ? 16 : 0;
    if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
      asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
                   "l"(gmem_ptr), "n"(16), "r"(src_in_bytes));
    } else {
      asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
                   "l"(gmem_ptr), "n"(16), "r"(src_in_bytes));
    }
  } else {
    if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
      asm volatile(
          "{\n"
          " .reg .pred p;\n"
          " setp.ne.b32 p, %0, 0;\n"
          " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
          "}\n" ::"r"((int)predicate),
          "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16));
    } else {
      asm volatile(
          "{\n"
          " .reg .pred p;\n"
          " setp.ne.b32 p, %0, 0;\n"
          " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
          "}\n" ::"r"((int)predicate),
          "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16));
    }
  }
#else
  if (predicate) {
    *((uint4*)smem_ptr) = *((uint4*)gmem_ptr);
  } else {
    if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
      *((uint4*)smem_ptr) = make_uint4(0, 0, 0, 0);
    }
  }
#endif
}

/*!
 * \brief Load specified number of bits per thread from global memory to shared memory
 * \tparam num_bits Number of bits to load, must be 128 or 256
 * \tparam prefetch_mode Whether to fetch additional data from global memory to L2
 * \tparam T Data type
 * \param smem_ptr Pointer to shared memory
 * \param gmem_ptr Pointer to global memory
 */
template <size_t num_bits, PrefetchMode prefetch_mode, typename T>
__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) {
  static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256");
  if constexpr (num_bits == 128) {
    load_128b<prefetch_mode>(smem_ptr, gmem_ptr);
  } else {
    load_128b<prefetch_mode>(smem_ptr, gmem_ptr);
    load_128b<prefetch_mode>(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T));
  }
}

/*!
 * \brief Load specified number of bits per thread from global memory to shared memory with
 * predicate
 * \tparam num_bits Number of bits to load, must be 128 or 256
 * \tparam prefetch_mode Whether to fetch additional data from global memory to L2
 * \tparam fill_mode Whether to fill zero to shared memory when predicate is false
 * \tparam T Data type
 * \param smem_ptr Pointer to shared memory
 * \param gmem_ptr Pointer to global memory
 * \param predicate Predicate value
 * \note fill zero is slower than not fill zero
 */
template <size_t num_bits, PrefetchMode prefetch_mode, SharedMemFillMode fill_mode, typename T>
__device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) {
  static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256");
  if constexpr (num_bits == 128) {
    pred_load_128b<prefetch_mode, fill_mode>(smem_ptr, gmem_ptr, predicate);
  } else {
    pred_load_128b<prefetch_mode, fill_mode>(smem_ptr, gmem_ptr, predicate);
    pred_load_128b<prefetch_mode, fill_mode>(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T),
                                             predicate);
  }
}

}  // namespace cp_async

}  // namespace flashinfer

#endif  // FLASHINFER_CP_ASYNC_CUH_
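A minimal sketch of how these wrappers are typically combined (an illustrative pattern, not code from this commit): issue a batch of asynchronous 128-bit copies, commit them as one group, then wait for the group before reading the staged tile from shared memory.

```cuda
#include "flashinfer/cp_async.cuh"

using namespace flashinfer::cp_async;

// Each thread stages one 16-byte chunk of a row into shared memory.
template <typename T>
__device__ void stage_row(T* smem_row, const T* gmem_row, uint32_t valid_chunks) {
  constexpr uint32_t elems_per_chunk = 16 / sizeof(T);
  uint32_t c = threadIdx.x;
  // Predicated async copy: out-of-range threads fill zeros instead of garbage.
  pred_load<128, PrefetchMode::kPrefetch, SharedMemFillMode::kFillZero>(
      smem_row + c * elems_per_chunk, gmem_row + c * elems_per_chunk, c < valid_chunks);
  commit_group();   // close the group of cp.async instructions issued above
  wait_group<0>();  // wait until that group has landed in shared memory
  __syncthreads();  // make the staged tile visible to the whole block
}
```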
flashinfer/mma.cuh ADDED
@@ -0,0 +1,410 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_MMA_CUH_
#define FLASHINFER_MMA_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <type_traits>

namespace flashinfer {

namespace mma {

#if (__CUDACC_VER_MAJOR__ >= 11)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
#define FLASHINFER_STMATRIX_M8N8X4_ENABLED
#endif
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
#define FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED
#define FLASHINFER_MMA_F16F16F16_M16N8K16_ENABLED
#endif
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750))
#define FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED
#define FLASHINFER_MMA_F16F16F16_M16N8K8_ENABLED
#define FLASHINFER_LDMATRIX_M8N8X4_ENABLED
#endif
#endif

enum class MMAMode {
  kInit = 0U,
  kInplaceUpdate = 1U,
};

/*!
 * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared memory
 * to fragment
 * \tparam T data type of the fragment
 * \param R pointer to the fragment
 * \param smem_ptr pointer to the shared memory
 */
template <typename T>
__device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED
  uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
               : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
               : "r"(smem_int_ptr));
#else
#error "Unsupported CUDA architecture for ldmatrix instruction"
#endif
}

/*!
 * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from
 * shared memory to fragment and transposes the fragment
 * \tparam T data type of the fragment
 * \param R pointer to the fragment
 * \param smem_ptr pointer to the shared memory
 */
template <typename T>
__device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, T* smem_ptr) {
#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED
  uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
               : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
               : "r"(smem_int_ptr));
#else
#error "Unsupported CUDA architecture for ldmatrix instruction"
#endif
}

/*!
 * \brief Wrapper of PTX stmatrix m8n8.x4 instruction, stores data from fragment
 * to shared memory
 * \tparam T data type of the fragment
 * \param R pointer to the fragment
 * \param smem_ptr pointer to the shared memory
 */
template <typename T>
__device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) {
#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
  uint32_t smem_int_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n"
               :
               : "r"(smem_int_ptr), "r"(R[0]), "r"(R[1]), "r"(R[2]), "r"(R[3]));
#else
  // Fallback implementation, slower than PTX instruction
  const uint32_t tx = threadIdx.x;
  uint4 word;
#pragma unroll
  for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) {
    word.x = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4);
    word.y = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 1);
    word.z = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 2);
    word.w = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 3);
    if (tx / 8 == reg_id) {
      *(uint4*)smem_ptr = word;
    }
  }
#endif
}

/*!
 * \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix
 * multiplication, accumulated in f32.
 * \tparam T data type of the fragment
 * \tparam mma_mode whether we are initializing the accumulator or updating it
 * \param C pointer to the accumulator
 * \param A pointer to the fragment of matrix A
 * \param B pointer to the fragment of matrix B
 */
template <typename T, MMAMode mma_mode = MMAMode::kInplaceUpdate>
__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A,
                                                                     uint32_t* B) {
#if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED)
  if constexpr (mma_mode == MMAMode::kInit) {
    if constexpr (std::is_same<T, half>::value) {
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f),
            "f"(0.f), "f"(0.f));
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f),
            "f"(0.f), "f"(0.f));
    } else {
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f),
            "f"(0.f), "f"(0.f));
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f),
            "f"(0.f), "f"(0.f));
    }
  } else {
    if constexpr (std::is_same<T, half>::value) {
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]),
            "f"(C[2]), "f"(C[3]));
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]),
            "f"(C[6]), "f"(C[7]));
    } else {
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]),
            "f"(C[2]), "f"(C[3]));
      asm volatile(
          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
          "{%0, %1, %2, %3},"
          "{%4, %5, %6, %7},"
          "{%8, %9},"
          "{%10, %11, %12, %13};\n"
          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
          : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]),
            "f"(C[6]), "f"(C[7]));
    }
  }
#elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED)
  static_assert(std::is_same<T, half>::value, "bf16 mma instruction is not supported on sm_75");
  if constexpr (mma_mode == MMAMode::kInit) {
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5},"
        "{%6},"
        "{%7, %8, %9, %10};\n"
        : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5},"
        "{%6},"
        "{%7, %8, %9, %10};\n"
        : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
        : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5},"
        "{%6},"
        "{%7, %8, %9, %10};\n"
        : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
        : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5},"
        "{%6},"
        "{%7, %8, %9, %10};\n"
        : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
        : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), "f"(0.f));
  } else {
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5},"
        "{%6},"
        "{%7, %8, %9, %10};\n"
        : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5},"
        "{%6},"
        "{%7, %8, %9, %10};\n"
        : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
        : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5},"
        "{%6},"
        "{%7, %8, %9, %10};\n"
        : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
        : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3},"
        "{%4, %5},"
        "{%6},"
        "{%7, %8, %9, %10};\n"
        : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
        : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), "f"(C[5]), "f"(C[6]), "f"(C[7]));
  }
#else
#error "Unsupported CUDA architecture for mma instruction"
#endif
}

/*!
 * \brief Wrapper of two mma m16n8k16 instructions for row major and column major f16 matrix
 * multiplication, accumulated in f16.
 * \tparam mma_mode whether we are initializing the accumulator or updating it
 * \param C pointer to the accumulator
 * \param A pointer to the fragment of matrix A
 * \param B pointer to the fragment of matrix B
 */
template <MMAMode mma_mode = MMAMode::kInplaceUpdate>
__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16(uint32_t* C, uint32_t* A,
                                                                     uint32_t* B) {
#if defined(FLASHINFER_MMA_F16F16F16_M16N8K16_ENABLED)
  if constexpr (mma_mode == MMAMode::kInit) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, %9};\n"
        : "=r"(C[0]), "=r"(C[1])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0));
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, %9};\n"
        : "=r"(C[2]), "=r"(C[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0));
  } else {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, %9};\n"
        : "=r"(C[0]), "=r"(C[1])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]));
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3, %4, %5},"
        "{%6, %7},"
        "{%8, %9};\n"
        : "=r"(C[2]), "=r"(C[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[2]), "r"(C[3]));
  }
#elif defined(FLASHINFER_MMA_F16F16F16_M16N8K8_ENABLED)
  if constexpr (mma_mode == MMAMode::kInit) {
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3},"
        "{%4},"
        "{%5, %6};\n"
        : "=r"(C[0]), "=r"(C[1])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(0), "r"(0));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3},"
        "{%4},"
        "{%5, %6};\n"
        : "=r"(C[0]), "=r"(C[1])
        : "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(0), "r"(0));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3},"
        "{%4},"
        "{%5, %6};\n"
        : "=r"(C[2]), "=r"(C[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[2]), "r"(0), "r"(0));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3},"
        "{%4},"
        "{%5, %6};\n"
        : "=r"(C[2]), "=r"(C[3])
        : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(0), "r"(0));
  } else {
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3},"
        "{%4},"
        "{%5, %6};\n"
        : "=r"(C[0]), "=r"(C[1])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1]));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3},"
        "{%4},"
        "{%5, %6};\n"
        : "=r"(C[0]), "=r"(C[1])
        : "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(C[0]), "r"(C[1]));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3},"
        "{%4},"
        "{%5, %6};\n"
        : "=r"(C[2]), "=r"(C[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[2]), "r"(C[2]), "r"(C[3]));
    asm volatile(
        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
        "{%0, %1},"
        "{%2, %3},"
        "{%4},"
        "{%5, %6};\n"
        : "=r"(C[2]), "=r"(C[3])
        : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(C[2]), "r"(C[3]));
  }
#else
#error "Unsupported CUDA architecture for mma instruction"
#endif
}

}  // namespace mma

}  // namespace flashinfer

#endif  // FLASHINFER_MMA_CUH_
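A short sketch of the intended usage of the f32-accumulating wrapper (illustrative only; fragment loading is elided and assumed to follow the ldmatrix helpers above): initialize the accumulator with `MMAMode::kInit` on the first k-tile, then keep accumulating in place on the remaining k-tiles.

```cuda
#include "flashinfer/mma.cuh"

using namespace flashinfer::mma;

// One warp accumulates a 16x16 (f16 in, f32 out) output tile over num_k_tiles
// k-steps. a_frag/b_frag are assumed to be (re)loaded per k-tile, e.g. with
// ldmatrix_m8n8x4 / ldmatrix_m8n8x4_trans from suitably laid-out shared memory.
__device__ void warp_tile_gemm(float (&c_frag)[8], uint32_t (&a_frag)[4],
                               uint32_t (&b_frag)[4], int num_k_tiles) {
  // First k-tile: overwrite the accumulator.
  mma_sync_m16n16k16_row_col_f16f16f32<half, MMAMode::kInit>(c_frag, a_frag, b_frag);
  for (int k = 1; k < num_k_tiles; ++k) {
    // ... reload a_frag / b_frag for k-tile k ...
    // Remaining k-tiles: default kInplaceUpdate mode accumulates into c_frag.
    mma_sync_m16n16k16_row_col_f16f16f32<half>(c_frag, a_frag, b_frag);
  }
}
```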
flashinfer/permuted_smem.cuh ADDED
@@ -0,0 +1,95 @@
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_PERMUTED_SMEM_CUH_
#define FLASHINFER_PERMUTED_SMEM_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cuda/pipeline>

#include "cp_async.cuh"
#include "mma.cuh"

namespace flashinfer {

// Each cell is 4 bytes.
using cell_t = uint4;

/*!
 * \brief Compute the number of elements that can be stored in a cell.
 * \tparam T The data type of the elements.
 */
template <typename T>
constexpr __host__ __device__ __forceinline__ uint32_t cell_capacity() {
  return sizeof(cell_t) / sizeof(T);
}

/*!
 * \brief The shared memory wrapper.
 */
struct smem_t {
  // The base pointer.
  cell_t* base;
  __device__ __forceinline__ smem_t() : base(nullptr) {}
  template <typename T>
  __device__ __forceinline__ smem_t(T* base) : base((cell_t*)base) {}

  /*!
   * \brief Compute the element offset given coordinates in a permuted shared memory.
   * \tparam stride The stride (in terms of cells) in the permuted shared memory.
   * \param i The row index.
   * \param j The column index.
   */
  template <uint32_t stride>
  static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) {
    return (i / 2) * stride * 2 + (j / 4) * 8 + (i % 2) * 4 + ((j % 4) ^ ((i / 2) % 4));
  }

  __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, uint32_t* R) {
    cell_t* smem_ptr = base + offset;
    mma::ldmatrix_m8n8x4(R, smem_ptr);
  }
  __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t offset, uint32_t* R) {
    cell_t* smem_ptr = base + offset;
    mma::stmatrix_m8n8x4(R, smem_ptr);
  }
  __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset, uint32_t* R) {
    cell_t* smem_ptr = base + offset;
    mma::ldmatrix_m8n8x4_trans(R, smem_ptr);
  }
  template <cp_async::SharedMemFillMode fill_mode, typename T>
  __device__ __forceinline__ void load_128b_async(uint32_t offset, const T* gptr, bool predicate) {
    cell_t* smem_ptr = base + offset;
    cp_async::pred_load_128b<cp_async::PrefetchMode::kPrefetch, fill_mode>(
        smem_ptr, reinterpret_cast<const cell_t*>(gptr), predicate);
  }
  template <typename T>
  __device__ __forceinline__ void load_128b_async(uint32_t offset, const T* gptr) {
    cell_t* smem_ptr = base + offset;
    cp_async::load_128b<cp_async::PrefetchMode::kPrefetch>(smem_ptr,
                                                           reinterpret_cast<const cell_t*>(gptr));
  }
  template <typename T>
  __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) {
    *reinterpret_cast<cell_t*>(gptr) = *(base + offset);
  }
};

}  // namespace flashinfer

#endif  // FLASHINFER_PERMUTED_SMEM_CUH_
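A minimal sketch of how `smem_t` ties the pieces together (an illustrative per-lane pattern, not code from this commit; in real kernels every lane of the warp computes its own permuted offset so the `ldmatrix` call sees one address per thread): compute a permuted cell offset, stage one 128-bit chunk from global memory asynchronously, then read a fragment back into registers.

```cuda
#include "flashinfer/permuted_smem.cuh"

using namespace flashinfer;

// stride is the row stride of the shared-memory tile, in cells (uint4 units).
// row/col are this lane's coordinates; frag receives 4 x 32-bit registers.
template <uint32_t stride, typename T>
__device__ void stage_and_load(smem_t smem, const T* gptr, uint32_t row,
                               uint32_t col, uint32_t* frag) {
  uint32_t offset = smem_t::get_permuted_offset<stride>(row, col);
  // Asynchronously copy one cell (128 bits) from global into permuted smem.
  smem.load_128b_async(offset, gptr);
  cp_async::commit_group();
  cp_async::wait_group<0>();
  __syncthreads();
  // Warp-wide ldmatrix read starting at this lane's permuted location.
  smem.ldmatrix_m8n8x4(offset, frag);
}
```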
flashinfer/vec_dtypes.cuh
ADDED
@@ -0,0 +1,1262 @@
1 |
+
/*
|
2 |
+
* Copyright (c) 2023 by FlashInfer team.
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
#ifndef VEC_DTYPES_CUH_
|
17 |
+
#define VEC_DTYPES_CUH_
|
18 |
+
|
19 |
+
#include <cuda_bf16.h>
|
20 |
+
#include <cuda_fp16.h>
|
21 |
+
#ifdef FLASHINFER_ENABLE_FP8
|
22 |
+
#include <cuda_fp8.h>
|
23 |
+
#endif
|
24 |
+
#include <cuda_runtime.h>
|
25 |
+
|
26 |
+
#include <type_traits>
|
27 |
+
|
28 |
+
namespace flashinfer {
|
29 |
+
|
30 |
+
#define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ __host__
|
31 |
+
|
32 |
+
template <typename float_t, size_t vec_size>
|
33 |
+
struct vec_t {
|
34 |
+
FLASHINFER_INLINE float_t& operator[](size_t i);
|
35 |
+
FLASHINFER_INLINE const float_t& operator[](size_t i) const;
|
36 |
+
FLASHINFER_INLINE void fill(float_t val);
|
37 |
+
FLASHINFER_INLINE void load(const float_t* ptr);
|
38 |
+
FLASHINFER_INLINE void store(float_t* ptr) const;
|
39 |
+
template <typename T>
|
40 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src);
|
41 |
+
template <typename T>
|
42 |
+
FLASHINFER_INLINE void cast_load(const T* ptr);
|
43 |
+
template <typename T>
|
44 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const;
|
45 |
+
FLASHINFER_INLINE static void memcpy(float_t* dst, const float_t* src);
|
46 |
+
FLASHINFER_INLINE float_t* ptr();
|
47 |
+
};
|
48 |
+
|
49 |
+
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
|
50 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<tgt_float_t, vec_size>& dst,
|
51 |
+
const vec_t<src_float_t, vec_size>& src) {
|
52 |
+
#pragma unroll
|
53 |
+
for (size_t i = 0; i < vec_size; ++i) {
|
54 |
+
dst[i] = tgt_float_t(src[i]);
|
55 |
+
}
|
56 |
+
}
|
57 |
+
|
58 |
+
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
|
59 |
+
FLASHINFER_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
|
60 |
+
const src_float_t* src_ptr) {
|
61 |
+
if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
|
62 |
+
dst.load(src_ptr);
|
63 |
+
} else {
|
64 |
+
vec_t<src_float_t, vec_size> tmp;
|
65 |
+
tmp.load(src_ptr);
|
66 |
+
dst.cast_from(tmp);
|
67 |
+
}
|
68 |
+
}
|
69 |
+
|
70 |
+
template <typename src_float_t, typename tgt_float_t, size_t vec_size>
|
71 |
+
FLASHINFER_INLINE void cast_store_impl(tgt_float_t* dst_ptr,
|
72 |
+
const vec_t<src_float_t, vec_size>& src) {
|
73 |
+
if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
|
74 |
+
src.store(dst_ptr);
|
75 |
+
} else {
|
76 |
+
vec_t<tgt_float_t, vec_size> tmp;
|
77 |
+
tmp.cast_from(src);
|
78 |
+
tmp.store(dst_ptr);
|
79 |
+
}
|
80 |
+
}
|
81 |
+
|
82 |
+
#ifdef FLASHINFER_ENABLE_FP8
|
83 |
+
/******************* vec_t<__nv_fp8_e4m3> *******************/
|
84 |
+
|
85 |
+
// __nv_fp8_e4m3 x 1
|
86 |
+
template <>
|
87 |
+
struct vec_t<__nv_fp8_e4m3, 1> {
|
88 |
+
__nv_fp8_e4m3 data;
|
89 |
+
|
90 |
+
FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)(&data))[i]; }
|
91 |
+
FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
|
92 |
+
return ((const __nv_fp8_e4m3*)(&data))[i];
|
93 |
+
}
|
94 |
+
FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
|
95 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
|
96 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr);
|
97 |
+
FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const;
|
98 |
+
template <typename T>
|
99 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
|
100 |
+
cast_from_impl(*this, src);
|
101 |
+
}
|
102 |
+
template <typename T>
|
103 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
104 |
+
cast_load_impl(*this, ptr);
|
105 |
+
}
|
106 |
+
template <typename T>
|
107 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
108 |
+
cast_store_impl(ptr, *this);
|
109 |
+
}
|
110 |
+
|
111 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src);
|
112 |
+
};
|
113 |
+
|
114 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { data = val; }
|
115 |
+
|
116 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) { data = *ptr; }
|
117 |
+
|
118 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(__nv_fp8_e4m3* ptr) const { *ptr = data; }
|
119 |
+
|
120 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(__nv_fp8_e4m3* dst,
|
121 |
+
const __nv_fp8_e4m3* src) {
|
122 |
+
*dst = *src;
|
123 |
+
}
|
124 |
+
|
125 |
+
// __nv_fp8_e4m3 x 2
|
126 |
+
template <>
|
127 |
+
struct vec_t<__nv_fp8_e4m3, 2> {
|
128 |
+
__nv_fp8x2_e4m3 data;
|
129 |
+
|
130 |
+
FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)(&data))[i]; }
|
131 |
+
FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
|
132 |
+
return ((const __nv_fp8_e4m3*)(&data))[i];
|
133 |
+
}
|
134 |
+
FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
|
135 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
|
136 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr);
|
137 |
+
FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const;
|
138 |
+
template <typename T>
|
139 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
|
140 |
+
cast_from_impl(*this, src);
|
141 |
+
}
|
142 |
+
template <typename T>
|
143 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
144 |
+
cast_load_impl(*this, ptr);
|
145 |
+
}
|
146 |
+
template <typename T>
|
147 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
148 |
+
cast_store_impl(ptr, *this);
|
149 |
+
}
|
150 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src);
|
151 |
+
};
|
152 |
+
|
153 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
|
154 |
+
data.__x = (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
|
155 |
+
}
|
156 |
+
|
157 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) {
|
158 |
+
data = *((__nv_fp8x2_e4m3*)ptr);
|
159 |
+
}
|
160 |
+
|
161 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(__nv_fp8_e4m3* ptr) const {
|
162 |
+
*((__nv_fp8x2_e4m3*)ptr) = data;
|
163 |
+
}
|
164 |
+
|
165 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(__nv_fp8_e4m3* dst,
|
166 |
+
const __nv_fp8_e4m3* src) {
|
167 |
+
*((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src);
|
168 |
+
}
|
169 |
+
|
170 |
+
// __nv_fp8_e4m3 x 4
|
171 |
+
|
172 |
+
template <>
|
173 |
+
struct vec_t<__nv_fp8_e4m3, 4> {
|
174 |
+
__nv_fp8x4_e4m3 data;
|
175 |
+
|
176 |
+
FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)(&data))[i]; }
|
177 |
+
FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
|
178 |
+
return ((const __nv_fp8_e4m3*)(&data))[i];
|
179 |
+
}
|
180 |
+
FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
|
181 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
|
182 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr);
|
183 |
+
FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const;
|
184 |
+
template <typename T>
|
185 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
|
186 |
+
cast_from_impl(*this, src);
|
187 |
+
}
|
188 |
+
template <typename T>
|
189 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
190 |
+
cast_load_impl(*this, ptr);
|
191 |
+
}
|
192 |
+
template <typename T>
|
193 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
194 |
+
cast_store_impl(ptr, *this);
|
195 |
+
}
|
196 |
+
|
197 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src);
|
198 |
+
};
|
199 |
+
|
200 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
|
201 |
+
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
202 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
203 |
+
}
|
204 |
+
|
205 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) {
|
206 |
+
data = *((__nv_fp8x4_e4m3*)ptr);
|
207 |
+
}
|
208 |
+
|
209 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(__nv_fp8_e4m3* ptr) const {
|
210 |
+
*((__nv_fp8x4_e4m3*)ptr) = data;
|
211 |
+
}
|
212 |
+
|
213 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(__nv_fp8_e4m3* dst,
|
214 |
+
const __nv_fp8_e4m3* src) {
|
215 |
+
*((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src);
|
216 |
+
}
|
217 |
+
|
218 |
+
// __nv_fp8_e4m3 x 8
|
219 |
+
|
220 |
+
template <>
|
221 |
+
struct vec_t<__nv_fp8_e4m3, 8> {
|
222 |
+
uint2 data;
|
223 |
+
|
224 |
+
FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)(&data))[i]; }
|
225 |
+
FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
|
226 |
+
return ((const __nv_fp8_e4m3*)(&data))[i];
|
227 |
+
}
|
228 |
+
FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
|
229 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val);
|
230 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr);
|
231 |
+
FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const;
|
232 |
+
template <typename T>
|
233 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 8>& src) {
|
234 |
+
cast_from_impl(*this, src);
|
235 |
+
}
|
236 |
+
template <typename T>
|
237 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
238 |
+
cast_load_impl(*this, ptr);
|
239 |
+
}
|
240 |
+
template <typename T>
|
241 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
242 |
+
cast_store_impl(ptr, *this);
|
243 |
+
}
|
244 |
+
|
245 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src);
|
246 |
+
};
|
247 |
+
|
248 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
|
249 |
+
((__nv_fp8x4_e4m3*)(&data.x))->__x =
|
250 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
251 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
252 |
+
((__nv_fp8x4_e4m3*)(&data.y))->__x =
|
253 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
254 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
255 |
+
}
|
256 |
+
|
257 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) {
|
258 |
+
data = *((uint2*)ptr);
|
259 |
+
}
|
260 |
+
|
261 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(__nv_fp8_e4m3* ptr) const {
|
262 |
+
*((uint2*)ptr) = data;
|
263 |
+
}
|
264 |
+
|
265 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(__nv_fp8_e4m3* dst,
|
266 |
+
const __nv_fp8_e4m3* src) {
|
267 |
+
*((uint2*)dst) = *((uint2*)src);
|
268 |
+
}
|
269 |
+
|
270 |
+
// __nv_fp8_e4m3 x 16 or more
|
271 |
+
template <size_t vec_size>
|
272 |
+
struct vec_t<__nv_fp8_e4m3, vec_size> {
|
273 |
+
uint4 data[vec_size / 16];
|
274 |
+
|
275 |
+
FLASHINFER_INLINE __nv_fp8_e4m3& operator[](size_t i) { return ((__nv_fp8_e4m3*)data)[i]; }
|
276 |
+
FLASHINFER_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
|
277 |
+
return ((const __nv_fp8_e4m3*)data)[i];
|
278 |
+
}
|
279 |
+
FLASHINFER_INLINE __nv_fp8_e4m3* ptr() { return reinterpret_cast<__nv_fp8_e4m3*>(&data); }
|
280 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) {
|
281 |
+
#pragma unroll
|
282 |
+
for (size_t i = 0; i < vec_size / 16; ++i) {
|
283 |
+
((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x =
|
284 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
285 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
286 |
+
((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x =
|
287 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
288 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
289 |
+
((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x =
|
290 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
291 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
292 |
+
((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x =
|
293 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
294 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
295 |
+
}
|
296 |
+
}
|
297 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e4m3* ptr) {
|
298 |
+
#pragma unroll
|
299 |
+
for (size_t i = 0; i < vec_size / 16; ++i) {
|
300 |
+
data[i] = ((uint4*)ptr)[i];
|
301 |
+
}
|
302 |
+
}
|
303 |
+
FLASHINFER_INLINE void store(__nv_fp8_e4m3* ptr) const {
|
304 |
+
#pragma unroll
|
305 |
+
for (size_t i = 0; i < vec_size / 16; ++i) {
|
306 |
+
((uint4*)ptr)[i] = data[i];
|
307 |
+
}
|
308 |
+
}
|
309 |
+
template <typename T>
|
310 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
311 |
+
cast_from_impl(*this, src);
|
312 |
+
}
|
313 |
+
template <typename T>
|
314 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
315 |
+
cast_load_impl(*this, ptr);
|
316 |
+
}
|
317 |
+
template <typename T>
|
318 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
319 |
+
cast_store_impl(ptr, *this);
|
320 |
+
}
|
321 |
+
|
322 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
|
323 |
+
#pragma unroll
|
324 |
+
for (size_t i = 0; i < vec_size / 16; ++i) {
|
325 |
+
((uint4*)dst)[i] = ((uint4*)src)[i];
|
326 |
+
}
|
327 |
+
}
|
328 |
+
};
|
329 |
+
|
330 |
+
/******************* vec_t<__nv_fp8_e5m2> *******************/
|
331 |
+
|
332 |
+
// __nv_fp8_e5m2 x 1
|
333 |
+
template <>
|
334 |
+
struct vec_t<__nv_fp8_e5m2, 1> {
|
335 |
+
__nv_fp8_e5m2 data;
|
336 |
+
|
337 |
+
FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)(&data))[i]; }
|
338 |
+
FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
|
339 |
+
return ((const __nv_fp8_e5m2*)(&data))[i];
|
340 |
+
}
|
341 |
+
FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
|
342 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
|
343 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr);
|
344 |
+
FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const;
|
345 |
+
template <typename T>
|
346 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
|
347 |
+
cast_from_impl(*this, src);
|
348 |
+
}
|
349 |
+
template <typename T>
|
350 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
351 |
+
cast_load_impl(*this, ptr);
|
352 |
+
}
|
353 |
+
template <typename T>
|
354 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
355 |
+
cast_store_impl(ptr, *this);
|
356 |
+
}
|
357 |
+
|
358 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src);
|
359 |
+
};
|
360 |
+
|
361 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { data = val; }
|
362 |
+
|
363 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) { data = *ptr; }
|
364 |
+
|
365 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(__nv_fp8_e5m2* ptr) const { *ptr = data; }
|
366 |
+
|
367 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(__nv_fp8_e5m2* dst,
|
368 |
+
const __nv_fp8_e5m2* src) {
|
369 |
+
*dst = *src;
|
370 |
+
}
|
371 |
+
|
372 |
+
// __nv_fp8_e5m2 x 2
|
373 |
+
template <>
|
374 |
+
struct vec_t<__nv_fp8_e5m2, 2> {
|
375 |
+
__nv_fp8x2_e5m2 data;
|
376 |
+
|
377 |
+
FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)(&data))[i]; }
|
378 |
+
FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
|
379 |
+
return ((const __nv_fp8_e5m2*)(&data))[i];
|
380 |
+
}
|
381 |
+
FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
|
382 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
|
383 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr);
|
384 |
+
FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const;
|
385 |
+
template <typename T>
|
386 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
|
387 |
+
cast_from_impl(*this, src);
|
388 |
+
}
|
389 |
+
template <typename T>
|
390 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
391 |
+
cast_load_impl(*this, ptr);
|
392 |
+
}
|
393 |
+
template <typename T>
|
394 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
395 |
+
cast_store_impl(ptr, *this);
|
396 |
+
}
|
397 |
+
|
398 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src);
|
399 |
+
};
|
400 |
+
|
401 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
|
402 |
+
data.__x = (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
|
403 |
+
}
|
404 |
+
|
405 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) {
|
406 |
+
data = *((__nv_fp8x2_e5m2*)ptr);
|
407 |
+
}
|
408 |
+
|
409 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(__nv_fp8_e5m2* ptr) const {
|
410 |
+
*((__nv_fp8x2_e5m2*)ptr) = data;
|
411 |
+
}
|
412 |
+
|
413 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(__nv_fp8_e5m2* dst,
|
414 |
+
const __nv_fp8_e5m2* src) {
|
415 |
+
*((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src);
|
416 |
+
}
|
417 |
+
|
418 |
+
// __nv_fp8_e5m2 x 4
|
419 |
+
|
420 |
+
template <>
|
421 |
+
struct vec_t<__nv_fp8_e5m2, 4> {
|
422 |
+
__nv_fp8x4_e5m2 data;
|
423 |
+
|
424 |
+
FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)(&data))[i]; }
|
425 |
+
FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
|
426 |
+
return ((const __nv_fp8_e5m2*)(&data))[i];
|
427 |
+
}
|
428 |
+
FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
|
429 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
|
430 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr);
|
431 |
+
FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const;
|
432 |
+
template <typename T>
|
433 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
|
434 |
+
cast_from_impl(*this, src);
|
435 |
+
}
|
436 |
+
template <typename T>
|
437 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
438 |
+
cast_load_impl(*this, ptr);
|
439 |
+
}
|
440 |
+
template <typename T>
|
441 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
442 |
+
cast_store_impl(ptr, *this);
|
443 |
+
}
|
444 |
+
|
445 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src);
|
446 |
+
};
|
447 |
+
|
448 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
|
449 |
+
data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
450 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
451 |
+
}
|
452 |
+
|
453 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) {
|
454 |
+
data = *((__nv_fp8x4_e5m2*)ptr);
|
455 |
+
}
|
456 |
+
|
457 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(__nv_fp8_e5m2* ptr) const {
|
458 |
+
*((__nv_fp8x4_e5m2*)ptr) = data;
|
459 |
+
}
|
460 |
+
|
461 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(__nv_fp8_e5m2* dst,
|
462 |
+
const __nv_fp8_e5m2* src) {
|
463 |
+
*((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src);
|
464 |
+
}
|
465 |
+
|
466 |
+
// __nv_fp8_e5m2 x 8
|
467 |
+
|
468 |
+
template <>
|
469 |
+
struct vec_t<__nv_fp8_e5m2, 8> {
|
470 |
+
uint2 data;
|
471 |
+
|
472 |
+
FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)(&data))[i]; }
|
473 |
+
FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
|
474 |
+
return ((const __nv_fp8_e5m2*)(&data))[i];
|
475 |
+
}
|
476 |
+
FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
|
477 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
|
478 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr);
|
479 |
+
FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const;
|
480 |
+
template <typename T>
|
481 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 8>& src) {
|
482 |
+
cast_from_impl(*this, src);
|
483 |
+
}
|
484 |
+
template <typename T>
|
485 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
486 |
+
cast_load_impl(*this, ptr);
|
487 |
+
}
|
488 |
+
template <typename T>
|
489 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
490 |
+
cast_store_impl(ptr, *this);
|
491 |
+
}
|
492 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src);
|
493 |
+
};
|
494 |
+
|
495 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
|
496 |
+
((__nv_fp8x4_e5m2*)(&data.x))->__x =
|
497 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
498 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
499 |
+
((__nv_fp8x4_e5m2*)(&data.y))->__x =
|
500 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
501 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
502 |
+
}
|
503 |
+
|
504 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) {
|
505 |
+
data = *((uint2*)ptr);
|
506 |
+
}
|
507 |
+
|
508 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(__nv_fp8_e5m2* ptr) const {
|
509 |
+
*((uint2*)ptr) = data;
|
510 |
+
}
|
511 |
+
|
512 |
+
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(__nv_fp8_e5m2* dst,
|
513 |
+
const __nv_fp8_e5m2* src) {
|
514 |
+
*((uint2*)dst) = *((uint2*)src);
|
515 |
+
}
|
516 |
+
|
517 |
+
// __nv_fp8_e5m2 x 16 or more
|
518 |
+
|
519 |
+
template <size_t vec_size>
|
520 |
+
struct vec_t<__nv_fp8_e5m2, vec_size> {
|
521 |
+
uint4 data[vec_size / 16];
|
522 |
+
|
523 |
+
FLASHINFER_INLINE __nv_fp8_e5m2& operator[](size_t i) { return ((__nv_fp8_e5m2*)data)[i]; }
|
524 |
+
FLASHINFER_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
|
525 |
+
return ((const __nv_fp8_e5m2*)data)[i];
|
526 |
+
}
|
527 |
+
FLASHINFER_INLINE __nv_fp8_e5m2* ptr() { return reinterpret_cast<__nv_fp8_e5m2*>(&data); }
|
528 |
+
FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) {
|
529 |
+
#pragma unroll
|
530 |
+
for (size_t i = 0; i < vec_size / 16; ++i) {
|
531 |
+
((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x =
|
532 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
533 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
534 |
+
((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x =
|
535 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
536 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
537 |
+
((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x =
|
538 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
539 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
540 |
+
((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x =
|
541 |
+
(__nv_fp8x4_storage_t(val.__x) << 24) | (__nv_fp8x4_storage_t(val.__x) << 16) |
|
542 |
+
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
|
543 |
+
}
|
544 |
+
}
|
545 |
+
FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr) {
|
546 |
+
#pragma unroll
|
547 |
+
for (size_t i = 0; i < vec_size / 16; ++i) {
|
548 |
+
data[i] = ((uint4*)ptr)[i];
|
549 |
+
}
|
550 |
+
}
|
551 |
+
FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const {
|
552 |
+
#pragma unroll
|
553 |
+
for (size_t i = 0; i < vec_size / 16; ++i) {
|
554 |
+
((uint4*)ptr)[i] = data[i];
|
555 |
+
}
|
556 |
+
}
|
557 |
+
template <typename T>
|
558 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
559 |
+
cast_from_impl(*this, src);
|
560 |
+
}
|
561 |
+
template <typename T>
|
562 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
563 |
+
cast_load_impl(*this, ptr);
|
564 |
+
}
|
565 |
+
template <typename T>
|
566 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
567 |
+
cast_store_impl(ptr, *this);
|
568 |
+
}
|
569 |
+
FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
|
570 |
+
#pragma unroll
|
571 |
+
for (size_t i = 0; i < vec_size / 16; ++i) {
|
572 |
+
((uint4*)dst)[i] = ((uint4*)src)[i];
|
573 |
+
}
|
574 |
+
}
|
575 |
+
};
|
576 |
+
#endif
|
577 |
+
|
578 |
+
/******************* vec_t<half> *******************/
|
579 |
+
|
580 |
+
// half x 1
|
581 |
+
template <>
|
582 |
+
struct vec_t<half, 1> {
|
583 |
+
half data;
|
584 |
+
|
585 |
+
FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
|
586 |
+
FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; }
|
587 |
+
FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
|
588 |
+
FLASHINFER_INLINE void fill(half val);
|
589 |
+
FLASHINFER_INLINE void load(const half* ptr);
|
590 |
+
FLASHINFER_INLINE void store(half* ptr) const;
|
591 |
+
template <typename T>
|
592 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
|
593 |
+
cast_from_impl(*this, src);
|
594 |
+
}
|
595 |
+
template <typename T>
|
596 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
597 |
+
cast_load_impl(*this, ptr);
|
598 |
+
}
|
599 |
+
template <typename T>
|
600 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
601 |
+
cast_store_impl(ptr, *this);
|
602 |
+
}
|
603 |
+
|
604 |
+
FLASHINFER_INLINE static void memcpy(half* dst, const half* src);
|
605 |
+
};
|
606 |
+
|
607 |
+
FLASHINFER_INLINE void vec_t<half, 1>::fill(half val) { data = val; }
|
608 |
+
|
609 |
+
FLASHINFER_INLINE void vec_t<half, 1>::load(const half* ptr) { data = *ptr; }
|
610 |
+
|
611 |
+
FLASHINFER_INLINE void vec_t<half, 1>::store(half* ptr) const { *ptr = data; }
|
612 |
+
|
613 |
+
FLASHINFER_INLINE void vec_t<half, 1>::memcpy(half* dst, const half* src) { *dst = *src; }
|
614 |
+
|
615 |
+
// half x 2
|
616 |
+
template <>
|
617 |
+
struct vec_t<half, 2> {
|
618 |
+
half2 data;
|
619 |
+
|
620 |
+
FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
|
621 |
+
FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; }
|
622 |
+
FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
|
623 |
+
FLASHINFER_INLINE void fill(half val);
|
624 |
+
FLASHINFER_INLINE void load(const half* ptr);
|
625 |
+
FLASHINFER_INLINE void store(half* ptr) const;
|
626 |
+
template <typename T>
|
627 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
|
628 |
+
cast_from_impl(*this, src);
|
629 |
+
}
|
630 |
+
template <typename T>
|
631 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
632 |
+
cast_load_impl(*this, ptr);
|
633 |
+
}
|
634 |
+
template <typename T>
|
635 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
636 |
+
cast_store_impl(ptr, *this);
|
637 |
+
}
|
638 |
+
|
639 |
+
FLASHINFER_INLINE static void memcpy(half* dst, const half* src);
|
640 |
+
};
|
641 |
+
|
642 |
+
FLASHINFER_INLINE void vec_t<half, 2>::fill(half val) { data = make_half2(val, val); }
|
643 |
+
|
644 |
+
FLASHINFER_INLINE void vec_t<half, 2>::load(const half* ptr) { data = *((half2*)ptr); }
|
645 |
+
|
646 |
+
FLASHINFER_INLINE void vec_t<half, 2>::store(half* ptr) const { *((half2*)ptr) = data; }
|
647 |
+
|
648 |
+
FLASHINFER_INLINE void vec_t<half, 2>::memcpy(half* dst, const half* src) {
|
649 |
+
*((half2*)dst) = *((half2*)src);
|
650 |
+
}
|
651 |
+
|
652 |
+
// half x 4
|
653 |
+
|
654 |
+
template <>
|
655 |
+
struct vec_t<half, 4> {
|
656 |
+
uint2 data;
|
657 |
+
|
658 |
+
FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
|
659 |
+
FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; }
|
660 |
+
FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
|
661 |
+
FLASHINFER_INLINE void fill(half val);
|
662 |
+
FLASHINFER_INLINE void load(const half* ptr);
|
663 |
+
FLASHINFER_INLINE void store(half* ptr) const;
|
664 |
+
template <typename T>
|
665 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
|
666 |
+
cast_from_impl(*this, src);
|
667 |
+
}
|
668 |
+
template <typename T>
|
669 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
670 |
+
cast_load_impl(*this, ptr);
|
671 |
+
}
|
672 |
+
template <typename T>
|
673 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
674 |
+
cast_store_impl(ptr, *this);
|
675 |
+
}
|
676 |
+
FLASHINFER_INLINE static void memcpy(half* dst, const half* src);
|
677 |
+
};
|
678 |
+
|
679 |
+
FLASHINFER_INLINE void vec_t<half, 4>::fill(half val) {
|
680 |
+
*(half2*)(&data.x) = make_half2(val, val);
|
681 |
+
*(half2*)(&data.y) = make_half2(val, val);
|
682 |
+
}
|
683 |
+
|
684 |
+
FLASHINFER_INLINE void vec_t<half, 4>::load(const half* ptr) { data = *((uint2*)ptr); }
|
685 |
+
|
686 |
+
FLASHINFER_INLINE void vec_t<half, 4>::store(half* ptr) const { *((uint2*)ptr) = data; }
|
687 |
+
|
688 |
+
FLASHINFER_INLINE void vec_t<half, 4>::memcpy(half* dst, const half* src) {
|
689 |
+
*((uint2*)dst) = *((uint2*)src);
|
690 |
+
}
|
691 |
+
|
692 |
+
// half x 8 or more
|
693 |
+
|
694 |
+
template <size_t vec_size>
|
695 |
+
struct vec_t<half, vec_size> {
|
696 |
+
uint4 data[vec_size / 8];
|
697 |
+
FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)data)[i]; }
|
698 |
+
FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)data)[i]; }
|
699 |
+
FLASHINFER_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
|
700 |
+
FLASHINFER_INLINE void fill(half val) {
|
701 |
+
#pragma unroll
|
702 |
+
for (size_t i = 0; i < vec_size / 8; ++i) {
|
703 |
+
*(half2*)(&(data[i].x)) = make_half2(val, val);
|
704 |
+
*(half2*)(&(data[i].y)) = make_half2(val, val);
|
705 |
+
*(half2*)(&(data[i].z)) = make_half2(val, val);
|
706 |
+
*(half2*)(&(data[i].w)) = make_half2(val, val);
|
707 |
+
}
|
708 |
+
}
|
709 |
+
FLASHINFER_INLINE void load(const half* ptr) {
|
710 |
+
#pragma unroll
|
711 |
+
for (size_t i = 0; i < vec_size / 8; ++i) {
|
712 |
+
data[i] = ((uint4*)ptr)[i];
|
713 |
+
}
|
714 |
+
}
|
715 |
+
FLASHINFER_INLINE void store(half* ptr) const {
|
716 |
+
#pragma unroll
|
717 |
+
for (size_t i = 0; i < vec_size / 8; ++i) {
|
718 |
+
((uint4*)ptr)[i] = data[i];
|
719 |
+
}
|
720 |
+
}
|
721 |
+
template <typename T>
|
722 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
723 |
+
cast_from_impl(*this, src);
|
724 |
+
}
|
725 |
+
template <typename T>
|
726 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
727 |
+
cast_load_impl(*this, ptr);
|
728 |
+
}
|
729 |
+
template <typename T>
|
730 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
731 |
+
cast_store_impl(ptr, *this);
|
732 |
+
}
|
733 |
+
FLASHINFER_INLINE static void memcpy(half* dst, const half* src) {
|
734 |
+
#pragma unroll
|
735 |
+
for (size_t i = 0; i < vec_size / 8; ++i) {
|
736 |
+
((uint4*)dst)[i] = ((uint4*)src)[i];
|
737 |
+
}
|
738 |
+
}
|
739 |
+
};
|
740 |
+
|
741 |
+
/******************* vec_t<nv_bfloat16> *******************/
|
742 |
+
|
743 |
+
// nv_bfloat16 x 1
|
744 |
+
template <>
|
745 |
+
struct vec_t<nv_bfloat16, 1> {
|
746 |
+
nv_bfloat16 data;
|
747 |
+
FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { return ((nv_bfloat16*)(&data))[i]; }
|
748 |
+
FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const {
|
749 |
+
return ((const nv_bfloat16*)(&data))[i];
|
750 |
+
}
|
751 |
+
FLASHINFER_INLINE nv_bfloat16* ptr() { return reinterpret_cast<nv_bfloat16*>(&data); }
|
752 |
+
FLASHINFER_INLINE void fill(nv_bfloat16 val);
|
753 |
+
FLASHINFER_INLINE void load(const nv_bfloat16* ptr);
|
754 |
+
FLASHINFER_INLINE void store(nv_bfloat16* ptr) const;
|
755 |
+
template <typename T>
|
756 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
|
757 |
+
cast_from_impl(*this, src);
|
758 |
+
}
|
759 |
+
template <typename T>
|
760 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
761 |
+
cast_load_impl(*this, ptr);
|
762 |
+
}
|
763 |
+
template <typename T>
|
764 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
765 |
+
cast_store_impl(ptr, *this);
|
766 |
+
}
|
767 |
+
FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
|
768 |
+
};
|
769 |
+
|
770 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) { data = val; }
|
771 |
+
|
772 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) { data = *ptr; }
|
773 |
+
|
774 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const { *ptr = data; }
|
775 |
+
|
776 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16* dst, const nv_bfloat16* src) {
|
777 |
+
*dst = *src;
|
778 |
+
}
|
779 |
+
|
780 |
+
// nv_bfloat16 x 2
|
781 |
+
template <>
|
782 |
+
struct vec_t<nv_bfloat16, 2> {
|
783 |
+
nv_bfloat162 data;
|
784 |
+
|
785 |
+
FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { return ((nv_bfloat16*)(&data))[i]; }
|
786 |
+
FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const {
|
787 |
+
return ((const nv_bfloat16*)(&data))[i];
|
788 |
+
}
|
789 |
+
FLASHINFER_INLINE nv_bfloat16* ptr() { return reinterpret_cast<nv_bfloat16*>(&data); }
|
790 |
+
FLASHINFER_INLINE void fill(nv_bfloat16 val);
|
791 |
+
FLASHINFER_INLINE void load(const nv_bfloat16* ptr);
|
792 |
+
FLASHINFER_INLINE void store(nv_bfloat16* ptr) const;
|
793 |
+
template <typename T>
|
794 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
|
795 |
+
cast_from_impl(*this, src);
|
796 |
+
}
|
797 |
+
template <typename T>
|
798 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
799 |
+
cast_load_impl(*this, ptr);
|
800 |
+
}
|
801 |
+
template <typename T>
|
802 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
803 |
+
cast_store_impl(ptr, *this);
|
804 |
+
}
|
805 |
+
FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
|
806 |
+
};
|
807 |
+
|
808 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
|
809 |
+
data = make_bfloat162(val, val);
|
810 |
+
}
|
811 |
+
|
812 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
|
813 |
+
data = *((nv_bfloat162*)ptr);
|
814 |
+
}
|
815 |
+
|
816 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
|
817 |
+
*((nv_bfloat162*)ptr) = data;
|
818 |
+
}
|
819 |
+
|
820 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16* dst, const nv_bfloat16* src) {
|
821 |
+
*((nv_bfloat162*)dst) = *((nv_bfloat162*)src);
|
822 |
+
}
|
823 |
+
|
824 |
+
// nv_bfloat16 x 4
|
825 |
+
|
826 |
+
template <>
|
827 |
+
struct vec_t<nv_bfloat16, 4> {
|
828 |
+
uint2 data;
|
829 |
+
|
830 |
+
FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { return ((nv_bfloat16*)(&data))[i]; }
|
831 |
+
FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const {
|
832 |
+
return ((const nv_bfloat16*)(&data))[i];
|
833 |
+
}
|
834 |
+
FLASHINFER_INLINE nv_bfloat16* ptr() { return reinterpret_cast<nv_bfloat16*>(&data); }
|
835 |
+
FLASHINFER_INLINE void fill(nv_bfloat16 val);
|
836 |
+
FLASHINFER_INLINE void load(const nv_bfloat16* ptr);
|
837 |
+
FLASHINFER_INLINE void store(nv_bfloat16* ptr) const;
|
838 |
+
template <typename T>
|
839 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 4>& src) {
|
840 |
+
cast_from_impl(*this, src);
|
841 |
+
}
|
842 |
+
template <typename T>
|
843 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
844 |
+
cast_load_impl(*this, ptr);
|
845 |
+
}
|
846 |
+
template <typename T>
|
847 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
848 |
+
cast_store_impl(ptr, *this);
|
849 |
+
}
|
850 |
+
FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
|
851 |
+
};
|
852 |
+
|
853 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
|
854 |
+
*(nv_bfloat162*)(&data.x) = make_bfloat162(val, val);
|
855 |
+
*(nv_bfloat162*)(&data.y) = make_bfloat162(val, val);
|
856 |
+
}
|
857 |
+
|
858 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
|
859 |
+
data = *((uint2*)ptr);
|
860 |
+
}
|
861 |
+
|
862 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
|
863 |
+
*((uint2*)ptr) = data;
|
864 |
+
}
|
865 |
+
|
866 |
+
FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16* dst, const nv_bfloat16* src) {
|
867 |
+
*((uint2*)dst) = *((uint2*)src);
|
868 |
+
}
|
869 |
+
|
870 |
+
// nv_bfloat16 x 8 or more
|
871 |
+
|
872 |
+
template <size_t vec_size>
|
873 |
+
struct vec_t<nv_bfloat16, vec_size> {
|
874 |
+
uint4 data[vec_size / 8];
|
875 |
+
|
876 |
+
FLASHINFER_INLINE nv_bfloat16& operator[](size_t i) { return ((nv_bfloat16*)data)[i]; }
|
877 |
+
FLASHINFER_INLINE const nv_bfloat16& operator[](size_t i) const {
|
878 |
+
return ((const nv_bfloat16*)data)[i];
|
879 |
+
}
|
880 |
+
FLASHINFER_INLINE nv_bfloat16* ptr() { return reinterpret_cast<nv_bfloat16*>(&data); }
|
881 |
+
FLASHINFER_INLINE void fill(nv_bfloat16 val) {
|
882 |
+
#pragma unroll
|
883 |
+
for (size_t i = 0; i < vec_size / 8; ++i) {
|
884 |
+
*(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val);
|
885 |
+
*(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val);
|
886 |
+
*(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val);
|
887 |
+
*(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val);
|
888 |
+
}
|
889 |
+
}
|
890 |
+
FLASHINFER_INLINE void load(const nv_bfloat16* ptr) {
|
891 |
+
#pragma unroll
|
892 |
+
for (size_t i = 0; i < vec_size / 8; ++i) {
|
893 |
+
data[i] = ((uint4*)ptr)[i];
|
894 |
+
}
|
895 |
+
}
|
896 |
+
FLASHINFER_INLINE void store(nv_bfloat16* ptr) const {
|
897 |
+
#pragma unroll
|
898 |
+
for (size_t i = 0; i < vec_size / 8; ++i) {
|
899 |
+
((uint4*)ptr)[i] = data[i];
|
900 |
+
}
|
901 |
+
}
|
902 |
+
template <typename T>
|
903 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
904 |
+
cast_from_impl(*this, src);
|
905 |
+
}
|
906 |
+
template <typename T>
|
907 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
908 |
+
cast_load_impl(*this, ptr);
|
909 |
+
}
|
910 |
+
template <typename T>
|
911 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
912 |
+
cast_store_impl(ptr, *this);
|
913 |
+
}
|
914 |
+
FLASHINFER_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src) {
|
915 |
+
#pragma unroll
|
916 |
+
for (size_t i = 0; i < vec_size / 8; ++i) {
|
917 |
+
((uint4*)dst)[i] = ((uint4*)src)[i];
|
918 |
+
}
|
919 |
+
}
|
920 |
+
};
|
921 |
+
|
922 |
+
/******************* vec_t<float> *******************/
|
923 |
+
|
924 |
+
// float x 1
|
925 |
+
|
926 |
+
template <>
|
927 |
+
struct vec_t<float, 1> {
|
928 |
+
float data;
|
929 |
+
|
930 |
+
FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
|
931 |
+
FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; }
|
932 |
+
FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
933 |
+
FLASHINFER_INLINE void fill(float val);
|
934 |
+
FLASHINFER_INLINE void load(const float* ptr);
|
935 |
+
FLASHINFER_INLINE void store(float* ptr) const;
|
936 |
+
template <typename T>
|
937 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 1>& src) {
|
938 |
+
cast_from_impl(*this, src);
|
939 |
+
}
|
940 |
+
template <typename T>
|
941 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
942 |
+
cast_load_impl(*this, ptr);
|
943 |
+
}
|
944 |
+
template <typename T>
|
945 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
946 |
+
cast_store_impl(ptr, *this);
|
947 |
+
}
|
948 |
+
FLASHINFER_INLINE static void memcpy(float* dst, const float* src);
|
949 |
+
};
|
950 |
+
|
951 |
+
FLASHINFER_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
|
952 |
+
|
953 |
+
FLASHINFER_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
|
954 |
+
|
955 |
+
FLASHINFER_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
|
956 |
+
|
957 |
+
FLASHINFER_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) { *dst = *src; }
|
958 |
+
|
959 |
+
// float x 2
|
960 |
+
|
961 |
+
template <>
|
962 |
+
struct vec_t<float, 2> {
|
963 |
+
float2 data;
|
964 |
+
|
965 |
+
FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
|
966 |
+
FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; }
|
967 |
+
FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
968 |
+
FLASHINFER_INLINE void fill(float val);
|
969 |
+
FLASHINFER_INLINE void load(const float* ptr);
|
970 |
+
FLASHINFER_INLINE void store(float* ptr) const;
|
971 |
+
template <typename T>
|
972 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, 2>& src) {
|
973 |
+
cast_from_impl(*this, src);
|
974 |
+
}
|
975 |
+
template <typename T>
|
976 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
977 |
+
cast_load_impl(*this, ptr);
|
978 |
+
}
|
979 |
+
template <typename T>
|
980 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
981 |
+
cast_store_impl(ptr, *this);
|
982 |
+
}
|
983 |
+
FLASHINFER_INLINE static void memcpy(float* dst, const float* src);
|
984 |
+
};
|
985 |
+
|
986 |
+
FLASHINFER_INLINE void vec_t<float, 2>::fill(float val) { data = make_float2(val, val); }
|
987 |
+
|
988 |
+
FLASHINFER_INLINE void vec_t<float, 2>::load(const float* ptr) { data = *((float2*)ptr); }
|
989 |
+
|
990 |
+
FLASHINFER_INLINE void vec_t<float, 2>::store(float* ptr) const { *((float2*)ptr) = data; }
|
991 |
+
|
992 |
+
FLASHINFER_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
|
993 |
+
*((float2*)dst) = *((float2*)src);
|
994 |
+
}
|
995 |
+
|
996 |
+
// float x 4 or more
|
997 |
+
template <size_t vec_size>
|
998 |
+
struct vec_t<float, vec_size> {
|
999 |
+
float4 data[vec_size / 4];
|
1000 |
+
|
1001 |
+
FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
|
1002 |
+
FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(data))[i]; }
|
1003 |
+
FLASHINFER_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
|
1004 |
+
FLASHINFER_INLINE void fill(float val) {
|
1005 |
+
#pragma unroll
|
1006 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1007 |
+
data[i] = make_float4(val, val, val, val);
|
1008 |
+
}
|
1009 |
+
}
|
1010 |
+
FLASHINFER_INLINE void load(const float* ptr) {
|
1011 |
+
#pragma unroll
|
1012 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1013 |
+
data[i] = ((float4*)ptr)[i];
|
1014 |
+
}
|
1015 |
+
}
|
1016 |
+
FLASHINFER_INLINE void store(float* ptr) const {
|
1017 |
+
#pragma unroll
|
1018 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1019 |
+
((float4*)ptr)[i] = data[i];
|
1020 |
+
}
|
1021 |
+
}
|
1022 |
+
template <typename T>
|
1023 |
+
FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size>& src) {
|
1024 |
+
cast_from_impl(*this, src);
|
1025 |
+
}
|
1026 |
+
template <typename T>
|
1027 |
+
FLASHINFER_INLINE void cast_load(const T* ptr) {
|
1028 |
+
cast_load_impl(*this, ptr);
|
1029 |
+
}
|
1030 |
+
template <typename T>
|
1031 |
+
FLASHINFER_INLINE void cast_store(T* ptr) const {
|
1032 |
+
cast_store_impl(ptr, *this);
|
1033 |
+
}
|
1034 |
+
FLASHINFER_INLINE static void memcpy(float* dst, const float* src) {
|
1035 |
+
#pragma unroll
|
1036 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1037 |
+
((float4*)dst)[i] = ((float4*)src)[i];
|
1038 |
+
}
|
1039 |
+
}
|
1040 |
+
};
|
1041 |
+
|
1042 |
+
/******************* vec_t type cast *******************/
|
1043 |
+
|
1044 |
+
template <typename dst_t, typename src_t, size_t vec_size>
|
1045 |
+
FLASHINFER_INLINE void vec_cast(dst_t* dst, const src_t* src) {
|
1046 |
+
#pragma unroll
|
1047 |
+
for (size_t i = 0; i < vec_size; ++i) {
|
1048 |
+
dst[i] = src[i];
|
1049 |
+
}
|
1050 |
+
}
|
1051 |
+
|
1052 |
+
template <size_t vec_size>
|
1053 |
+
FLASHINFER_INLINE void vec_cast<float, half>(float* dst, const half* src) {
|
1054 |
+
#pragma unroll
|
1055 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1056 |
+
((float2*)dst)[i] = __half22float2(((half2*)src)[i]);
|
1057 |
+
}
|
1058 |
+
}
|
1059 |
+
|
1060 |
+
template <size_t vec_size>
|
1061 |
+
FLASHINFER_INLINE void vec_cast<half, float>(half* dst, const float* src) {
|
1062 |
+
#pragma unroll
|
1063 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1064 |
+
((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]);
|
1065 |
+
}
|
1066 |
+
}
|
1067 |
+
|
1068 |
+
template <size_t vec_size>
|
1069 |
+
FLASHINFER_INLINE void vec_cast<float, nv_bfloat16>(float* dst, const nv_bfloat16* src) {
|
1070 |
+
#pragma unroll
|
1071 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1072 |
+
((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]);
|
1073 |
+
}
|
1074 |
+
}
|
1075 |
+
|
1076 |
+
template <size_t vec_size>
|
1077 |
+
FLASHINFER_INLINE void vec_cast<nv_bfloat16, float>(nv_bfloat16* dst, const float* src) {
|
1078 |
+
#pragma unroll
|
1079 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1080 |
+
((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]);
|
1081 |
+
}
|
1082 |
+
}
|
1083 |
+
|
1084 |
+
template <size_t vec_size>
|
1085 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst,
|
1086 |
+
const vec_t<half, vec_size>& src) {
|
1087 |
+
if constexpr (vec_size == 1) {
|
1088 |
+
dst.data = float(src.data);
|
1089 |
+
} else {
|
1090 |
+
#pragma unroll
|
1091 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1092 |
+
((float2*)(&dst.data))[i] = __half22float2(((half2*)(&src.data))[i]);
|
1093 |
+
}
|
1094 |
+
}
|
1095 |
+
}
|
1096 |
+
|
1097 |
+
template <size_t vec_size>
|
1098 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<half, vec_size>& dst,
|
1099 |
+
const vec_t<float, vec_size>& src) {
|
1100 |
+
if constexpr (vec_size == 1) {
|
1101 |
+
dst.data = half(src.data);
|
1102 |
+
} else {
|
1103 |
+
#pragma unroll
|
1104 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1105 |
+
((half2*)(&dst.data))[i] = __float22half2_rn(((float2*)(&src.data))[i]);
|
1106 |
+
}
|
1107 |
+
}
|
1108 |
+
}
|
1109 |
+
|
1110 |
+
template <size_t vec_size>
|
1111 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst,
|
1112 |
+
const vec_t<nv_bfloat16, vec_size>& src) {
|
1113 |
+
if constexpr (vec_size == 1) {
|
1114 |
+
dst.data = float(src.data);
|
1115 |
+
} else {
|
1116 |
+
#pragma unroll
|
1117 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1118 |
+
((float2*)(&dst.data))[i] = __bfloat1622float2(((nv_bfloat162*)(&src.data))[i]);
|
1119 |
+
}
|
1120 |
+
}
|
1121 |
+
}
|
1122 |
+
|
1123 |
+
template <size_t vec_size>
|
1124 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<nv_bfloat16, vec_size>& dst,
|
1125 |
+
const vec_t<float, vec_size>& src) {
|
1126 |
+
if constexpr (vec_size == 1) {
|
1127 |
+
dst.data = nv_bfloat16(src.data);
|
1128 |
+
} else {
|
1129 |
+
#pragma unroll
|
1130 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1131 |
+
((nv_bfloat162*)(&dst.data))[i] = __float22bfloat162_rn(((float2*)(&src.data))[i]);
|
1132 |
+
}
|
1133 |
+
}
|
1134 |
+
}
|
1135 |
+
|
1136 |
+
#ifdef FLASHINFER_ENABLE_FP8
|
1137 |
+
|
1138 |
+
template <size_t vec_size>
|
1139 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst,
|
1140 |
+
const vec_t<__nv_fp8_e4m3, vec_size>& src) {
|
1141 |
+
if constexpr (vec_size == 1) {
|
1142 |
+
dst.data = float(src.data);
|
1143 |
+
} else if constexpr (vec_size == 2) {
|
1144 |
+
*(float2*)(&dst.data) = float2(*(__nv_fp8x2_e4m3*)(&src.data));
|
1145 |
+
} else {
|
1146 |
+
#pragma unroll
|
1147 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1148 |
+
((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3*)(&src.data))[i]);
|
1149 |
+
}
|
1150 |
+
}
|
1151 |
+
}
|
1152 |
+
|
1153 |
+
template <size_t vec_size>
|
1154 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<half, vec_size>& dst,
|
1155 |
+
const vec_t<__nv_fp8_e4m3, vec_size>& src) {
|
1156 |
+
if constexpr (vec_size == 1) {
|
1157 |
+
dst.data = float(src.data);
|
1158 |
+
} else {
|
1159 |
+
#pragma unroll
|
1160 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1161 |
+
((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3*)(&src.data))[i]);
|
1162 |
+
}
|
1163 |
+
}
|
1164 |
+
}
|
1165 |
+
|
1166 |
+
template <size_t vec_size>
|
1167 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst,
|
1168 |
+
const vec_t<float, vec_size>& src) {
|
1169 |
+
if constexpr (vec_size == 1) {
|
1170 |
+
dst.data = __nv_fp8_e4m3(src.data);
|
1171 |
+
} else if constexpr (vec_size == 2) {
|
1172 |
+
*(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(float2*)(&src.data));
|
1173 |
+
} else {
|
1174 |
+
#pragma unroll
|
1175 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1176 |
+
((__nv_fp8x4_e4m3*)(&dst.data))[i] = __nv_fp8x4_e4m3(((float4*)(&src.data))[i]);
|
1177 |
+
}
|
1178 |
+
}
|
1179 |
+
}
|
1180 |
+
|
1181 |
+
template <size_t vec_size>
|
1182 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e4m3, vec_size>& dst,
|
1183 |
+
const vec_t<half, vec_size>& src) {
|
1184 |
+
if constexpr (vec_size == 1) {
|
1185 |
+
dst.data = __nv_fp8_e4m3(src.data);
|
1186 |
+
} else if constexpr (vec_size == 2) {
|
1187 |
+
*(__nv_fp8x2_e4m3*)(&dst.data) = __nv_fp8x2_e4m3(*(half2*)(&src.data));
|
1188 |
+
} else {
|
1189 |
+
#pragma unroll
|
1190 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1191 |
+
// NOTE(Zihao): need to double check if we properly handle flo and fhi
|
1192 |
+
((__nv_fp8x4_e4m3*)(&dst.data))[i] =
|
1193 |
+
__nv_fp8x4_e4m3(((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]);
|
1194 |
+
}
|
1195 |
+
}
|
1196 |
+
}
|
1197 |
+
|
1198 |
+
template <size_t vec_size>
|
1199 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<float, vec_size>& dst,
|
1200 |
+
const vec_t<__nv_fp8_e5m2, vec_size>& src) {
|
1201 |
+
if constexpr (vec_size == 1) {
|
1202 |
+
dst.data = float(src.data);
|
1203 |
+
} else if constexpr (vec_size == 2) {
|
1204 |
+
*(float2*)(&dst.data) = float2(*(__nv_fp8x2_e5m2*)(&src.data));
|
1205 |
+
} else {
|
1206 |
+
#pragma unroll
|
1207 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1208 |
+
((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2*)(&src.data))[i]);
|
1209 |
+
}
|
1210 |
+
}
|
1211 |
+
}
|
1212 |
+
|
1213 |
+
template <size_t vec_size>
|
1214 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<half, vec_size>& dst,
|
1215 |
+
const vec_t<__nv_fp8_e5m2, vec_size>& src) {
|
1216 |
+
if constexpr (vec_size == 1) {
|
1217 |
+
dst.data = float(src.data);
|
1218 |
+
} else {
|
1219 |
+
#pragma unroll
|
1220 |
+
for (size_t i = 0; i < vec_size / 2; ++i) {
|
1221 |
+
((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2*)(&src.data))[i]);
|
1222 |
+
}
|
1223 |
+
}
|
1224 |
+
}
|
1225 |
+
|
1226 |
+
template <size_t vec_size>
|
1227 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst,
|
1228 |
+
const vec_t<float, vec_size>& src) {
|
1229 |
+
if constexpr (vec_size == 1) {
|
1230 |
+
dst.data = __nv_fp8_e5m2(src.data);
|
1231 |
+
} else if constexpr (vec_size == 2) {
|
1232 |
+
*(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(float2*)(&src.data));
|
1233 |
+
} else {
|
1234 |
+
#pragma unroll
|
1235 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1236 |
+
((__nv_fp8x4_e5m2*)(&dst.data))[i] = __nv_fp8x4_e5m2(((float4*)(&src.data))[i]);
|
1237 |
+
}
|
1238 |
+
}
|
1239 |
+
}
|
1240 |
+
|
1241 |
+
template <size_t vec_size>
|
1242 |
+
FLASHINFER_INLINE void cast_from_impl(vec_t<__nv_fp8_e5m2, vec_size>& dst,
|
1243 |
+
const vec_t<half, vec_size>& src) {
|
1244 |
+
if constexpr (vec_size == 1) {
|
1245 |
+
dst.data = __nv_fp8_e5m2(src.data);
|
1246 |
+
} else if constexpr (vec_size == 2) {
|
1247 |
+
*(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(half2*)(&src.data));
|
1248 |
+
} else {
|
1249 |
+
#pragma unroll
|
1250 |
+
for (size_t i = 0; i < vec_size / 4; ++i) {
|
1251 |
+
// NOTE(Zihao): need to double check if we properly handle flo and fhi
|
1252 |
+
((__nv_fp8x4_e5m2*)(&dst.data))[i] =
|
1253 |
+
__nv_fp8x4_e5m2(((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]);
|
1254 |
+
}
|
1255 |
+
}
|
1256 |
+
}
|
1257 |
+
|
1258 |
+
#endif // FLASHINFER_ENABLE_FP8
|
1259 |
+
|
1260 |
+
} // namespace flashinfer
|
1261 |
+
|
1262 |
+
#endif // VEC_DTYPES_CUH_
|
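The sgmv/bgmv kernels consume vec_dtypes.cuh through the `vec_t<T, vec_size>` wrapper: `load`/`store` move 128-bit packets, while `cast_load`/`cast_store` additionally convert between the storage dtype and the compute dtype. A minimal sketch of that pattern follows; the kernel name `scale_fp16` is hypothetical, the include path is assumed to mirror this repo's layout, and it assumes `n` is a multiple of 8 and 16-byte-aligned pointers.

```cuda
#include <cuda_fp16.h>

#include "flashinfer/vec_dtypes.cuh"

// Hypothetical example: y = scale * x with fp16 storage and fp32 math,
// 8 halves (one 128-bit packet) per thread.
__global__ void scale_fp16(half* __restrict__ y, const half* __restrict__ x,
                           float scale, int n) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
  if (idx + 8 > n) return;
  flashinfer::vec_t<float, 8> v;
  v.cast_load(x + idx);  // vectorized 128-bit load, half -> float
#pragma unroll
  for (size_t i = 0; i < 8; ++i) {
    v[i] *= scale;
  }
  v.cast_store(y + idx);  // float -> half, vectorized 128-bit store
}
```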
punica_kernels/punica_ops.cc
ADDED
@@ -0,0 +1,220 @@
1 |
+
#include <c10/cuda/CUDAStream.h>
|
2 |
+
#include <cuda_bf16.h>
|
3 |
+
#include <cuda_fp16.h>
|
4 |
+
#include <torch/all.h>
|
5 |
+
|
6 |
+
#include <cstdint>
|
7 |
+
|
8 |
+
#include "bgmv/bgmv_config.h"
|
9 |
+
#include "sgmv/sgmv.h"
|
10 |
+
#include "sgmv_flashinfer/sgmv_config.h"
|
11 |
+
|
12 |
+
//namespace
|
13 |
+
//{
|
14 |
+
|
15 |
+
//====== utils ======
|
16 |
+
|
17 |
+
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b)
|
18 |
+
{
|
19 |
+
return (uint64_t(a) << 32) | uint64_t(b);
|
20 |
+
}
|
21 |
+
|
22 |
+
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
23 |
+
|
24 |
+
#define CHECK_CONTIGUOUS(x) \
|
25 |
+
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
26 |
+
|
27 |
+
#define CHECK_INPUT(x) \
|
28 |
+
CHECK_CUDA(x); \
|
29 |
+
CHECK_CONTIGUOUS(x)
|
30 |
+
|
31 |
+
#define CHECK_DIM(d, x) \
|
32 |
+
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
33 |
+
|
34 |
+
#define CHECK_EQ(a, b) \
|
35 |
+
TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
36 |
+
|
37 |
+
#define CHECK_GE(a, b) \
|
38 |
+
TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
|
39 |
+
|
40 |
+
//====== dispatch pytorch dtype ======
|
41 |
+
|
42 |
+
#define _DISPATCH_SWITCH(cond, ...) \
|
43 |
+
[&]() -> bool { \
|
44 |
+
switch (cond) \
|
45 |
+
{ \
|
46 |
+
__VA_ARGS__ \
|
47 |
+
default: \
|
48 |
+
return false; \
|
49 |
+
} \
|
50 |
+
}()
|
51 |
+
|
52 |
+
#define _DISPATCH_DTYPE_CASE(enum_type, c_type_, ...) \
|
53 |
+
case enum_type: \
|
54 |
+
{ \
|
55 |
+
using c_type = c_type_; \
|
56 |
+
return __VA_ARGS__(); \
|
57 |
+
}
|
58 |
+
|
59 |
+
#define _DISPATCH_DTYPE_CASES(...) \
|
60 |
+
_DISPATCH_DTYPE_CASE(at::ScalarType::Half, nv_half, __VA_ARGS__) \
|
61 |
+
_DISPATCH_DTYPE_CASE(at::ScalarType::BFloat16, nv_bfloat16, __VA_ARGS__)
|
62 |
+
|
63 |
+
#define DISPATCH_TORCH_DTYPE(scalar_type, ...) \
|
64 |
+
_DISPATCH_SWITCH(scalar_type, _DISPATCH_DTYPE_CASES(__VA_ARGS__))
|
65 |
+
|
66 |
+
//====== bgmv ======
|
67 |
+
|
68 |
+
template <typename T>
|
69 |
+
inline bool launch_bgmv_kernel(T *Y, const T *X, T **W,
|
70 |
+
const int64_t *lora_indices,
|
71 |
+
uint16_t in_features, uint16_t out_features,
|
72 |
+
int64_t y_offset, int64_t full_y_size,
|
73 |
+
int64_t batch_size,
|
74 |
+
int64_t layer_idx, float scale)
|
75 |
+
{
|
76 |
+
switch (pack_u32(in_features, out_features))
|
77 |
+
{
|
78 |
+
#define CASE_ONESIDE(_T, feat_in, feat_out) \
|
79 |
+
case pack_u32(feat_in, feat_out): \
|
80 |
+
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
81 |
+
full_y_size, batch_size, \
|
82 |
+
layer_idx, scale); \
|
83 |
+
break;
|
84 |
+
#define CASE(_T, narrow, wide) \
|
85 |
+
CASE_ONESIDE(T, narrow, wide) \
|
86 |
+
CASE_ONESIDE(T, wide, narrow)
|
87 |
+
|
88 |
+
FOR_BGMV_WIDE_NARROW(CASE, _)
|
89 |
+
#undef CASE
|
90 |
+
#undef CASE_ONESIDE
|
91 |
+
default:
|
92 |
+
return false;
|
93 |
+
}
|
94 |
+
|
95 |
+
return true;
|
96 |
+
}
|
97 |
+
|
98 |
+
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
|
99 |
+
torch::Tensor indicies, int64_t layer_idx, double scale)
|
100 |
+
{
|
101 |
+
CHECK_INPUT(y);
|
102 |
+
CHECK_INPUT(x);
|
103 |
+
CHECK_INPUT(w_ptr);
|
104 |
+
CHECK_INPUT(indicies);
|
105 |
+
|
106 |
+
CHECK_DIM(2, y);
|
107 |
+
CHECK_DIM(2, x);
|
108 |
+
CHECK_DIM(1, w_ptr);
|
109 |
+
CHECK_DIM(1, indicies);
|
110 |
+
|
111 |
+
int64_t B = x.size(0);
|
112 |
+
int64_t h_in = x.size(1);
|
113 |
+
int64_t h_out = y.size(1);
|
114 |
+
CHECK_EQ(indicies.size(0), x.size(0));
|
115 |
+
CHECK_EQ(y.size(0), x.size(0));
|
116 |
+
bool ok = false;
|
117 |
+
if (h_in < 65536 && h_out < 65536)
|
118 |
+
{
|
119 |
+
switch (x.scalar_type())
|
120 |
+
{
|
121 |
+
case at::ScalarType::Half:
|
122 |
+
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
123 |
+
static_cast<nv_half *>(x.data_ptr()),
|
124 |
+
static_cast<nv_half **>(w_ptr.data_ptr()),
|
125 |
+
indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
|
126 |
+
layer_idx, scale);
|
127 |
+
break;
|
128 |
+
case at::ScalarType::BFloat16:
|
129 |
+
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
130 |
+
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
131 |
+
static_cast<nv_bfloat16 **>(w_ptr.data_ptr()),
|
132 |
+
indicies.data_ptr<int64_t>(), h_in, h_out, 0, h_out, B,
|
133 |
+
layer_idx, scale);
|
134 |
+
break;
|
135 |
+
default:
|
136 |
+
break;
|
137 |
+
}
|
138 |
+
}
|
139 |
+
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
140 |
+
" dtype=", x.scalar_type());
|
141 |
+
}
|
142 |
+
|
143 |
+
//====== sgmv ======
|
144 |
+
|
145 |
+
void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
|
146 |
+
torch::Tensor s_start, torch::Tensor s_end,
|
147 |
+
torch::Tensor tmp, int64_t layer_idx)
|
148 |
+
{
|
149 |
+
CHECK_INPUT(y);
|
150 |
+
CHECK_INPUT(x);
|
151 |
+
CHECK_INPUT(w_ptr);
|
152 |
+
CHECK_INPUT(s_start);
|
153 |
+
CHECK_INPUT(s_end);
|
154 |
+
CHECK_INPUT(tmp);
|
155 |
+
|
156 |
+
CHECK_DIM(2, y);
|
157 |
+
CHECK_DIM(2, x);
|
158 |
+
CHECK_DIM(1, w_ptr);
|
159 |
+
CHECK_DIM(1, s_start);
|
160 |
+
CHECK_DIM(1, s_end);
|
161 |
+
CHECK_DIM(1, tmp);
|
162 |
+
|
163 |
+
int num_problems = s_start.size(0);
|
164 |
+
int d_in = x.size(1);
|
165 |
+
int d_out = y.size(1);
|
166 |
+
CHECK_EQ(tmp.size(0), static_cast<int64_t>(sgmv_tmp_size(num_problems)));
|
167 |
+
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
|
168 |
+
bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&]
|
169 |
+
{ return sgmv<c_type>((c_type *)y.data_ptr(), (c_type *)x.data_ptr(), (c_type **)w_ptr.data_ptr(),
|
170 |
+
s_start.data_ptr<int32_t>(), s_end.data_ptr<int32_t>(),
|
171 |
+
tmp.data_ptr<uint8_t>(), num_problems, d_in, d_out,
|
172 |
+
layer_idx, stream); });
|
173 |
+
TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type());
|
174 |
+
}
|
175 |
+
|
176 |
+
void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
|
177 |
+
torch::Tensor s_start, torch::Tensor s_end, torch::Tensor tmp, int64_t layer_idx)
|
178 |
+
{
|
179 |
+
CHECK_INPUT(y);
|
180 |
+
CHECK_INPUT(x);
|
181 |
+
CHECK_INPUT(w_ptr);
|
182 |
+
CHECK_INPUT(s_start);
|
183 |
+
CHECK_INPUT(s_end);
|
184 |
+
CHECK_INPUT(tmp);
|
185 |
+
|
186 |
+
CHECK_DIM(2, y);
|
187 |
+
CHECK_DIM(2, x);
|
188 |
+
CHECK_DIM(1, w_ptr);
|
189 |
+
CHECK_DIM(1, s_start);
|
190 |
+
CHECK_DIM(1, s_end);
|
191 |
+
CHECK_DIM(1, tmp);
|
192 |
+
|
193 |
+
uint32_t num_problems = s_start.size(0);
|
194 |
+
uint32_t d_in = x.size(1);
|
195 |
+
uint32_t d_out = y.size(1);
|
196 |
+
CHECK_EQ(tmp.scalar_type(), at::ScalarType::Byte);
|
197 |
+
CHECK_EQ(tmp.size(0), 8 * 1024 * 1024);
|
198 |
+
cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
|
199 |
+
|
200 |
+
#define CASE(_T, D_OUT) \
|
201 |
+
case D_OUT: \
|
202 |
+
return sgmv_shrink<c_type, D_OUT>( \
|
203 |
+
(c_type *)y.data_ptr(), (c_type *)x.data_ptr(), \
|
204 |
+
(c_type **)w_ptr.data_ptr(), s_start.data_ptr<int32_t>(), s_end.data_ptr<int32_t>(), \
|
205 |
+
tmp.data_ptr<uint8_t>(), num_problems, d_in, layer_idx, stream);
|
206 |
+
|
207 |
+
bool ok = DISPATCH_TORCH_DTYPE(x.scalar_type(), [&]
|
208 |
+
{
|
209 |
+
switch (d_out) {
|
210 |
+
FOR_SGMV_NARROW(CASE, c_type);
|
211 |
+
default:
|
212 |
+
return false;
|
213 |
+
} });
|
214 |
+
|
215 |
+
#undef CASE
|
216 |
+
TORCH_CHECK(ok, "No suitable kernel.", " dtype=", x.scalar_type(),
|
217 |
+
" d_out=", d_out);
|
218 |
+
}
|
219 |
+
//} // namespace
|
220 |
+
|
sgmv/sgmv.h
ADDED
@@ -0,0 +1,10 @@
#pragma once
#include <cuda_runtime.h>

#include <cstdint>

template <typename DType>
bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
          void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx, cudaStream_t stream);

int64_t sgmv_tmp_size(int64_t num_problems);
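
These two declarations are the whole host-side contract of the CUTLASS path: the caller allocates a byte scratch buffer of `sgmv_tmp_size(num_problems)` bytes and passes it to `sgmv`, which carves the grouped-GEMM argument arrays out of it. A minimal allocation sketch from Python, mirroring what `torch-ext/punica_sgmv/__init__.py` does (the `punica_sgmv._ops` import path is an assumption about how the built extension is packaged):

    import torch
    from punica_sgmv._ops import ops  # assumed location of the registered custom ops

    num_problems = 4  # illustrative number of (s_start, s_end) segments
    # dispatch_sgmv_cutlass checks tmp.size(0) against sgmv_tmp_size(num_problems)
    tmp = torch.empty(
        (ops.sgmv_cutlass_tmp_size(num_problems),),
        dtype=torch.uint8,
        device="cuda",
    )
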
sgmv/sgmv_cutlass.cu
ADDED
@@ -0,0 +1,14 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "sgmv_cutlass.cuh"

template bool sgmv<nv_half>(nv_half *y, nv_half *x, nv_half **w,
                            int32_t *s_start, int32_t *s_end,
                            void *tmp_d, int num_problems, int d_in, int d_out,
                            int layer_idx, cudaStream_t stream);

template bool sgmv<nv_bfloat16>(nv_bfloat16 *y, nv_bfloat16 *x, nv_bfloat16 **w,
                                int32_t *s_start, int32_t *s_end,
                                void *tmp_d, int num_problems, int d_in, int d_out,
                                int layer_idx, cudaStream_t stream);
sgmv/sgmv_cutlass.cuh
ADDED
@@ -0,0 +1,180 @@
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>
#include <cstdio>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"

template <typename T>
struct cutlass_dtype {
  using type = T;
};

template <>
struct cutlass_dtype<half> {
  using type = cutlass::half_t;
};

template <>
struct cutlass_dtype<nv_bfloat16> {
  using type = cutlass::bfloat16_t;
};

template <typename T>
__global__ void precompute_sgmv_args(cutlass::gemm::GemmCoord *all_problems,
                                     T **ptr_y, T **ptr_x, T **ptr_w,
                                     int64_t *ld_y, int64_t *ld_x,
                                     int64_t *ld_w, T *y, T *x, T **w,
                                     int32_t *s_start, int32_t *s_end,
                                     int d_in, int d_out,
                                     int layer_idx) {
  int i = blockIdx.x;
  int m = s_end[i] - s_start[i], k = d_in, n = d_out;
  if (m <= 0) {
    m = 0;
    n = 0;
    k = 0;
  }
  all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
  ptr_w[i] = w[i] + layer_idx * d_in * d_out;
  ptr_x[i] = x + s_start[i] * d_in;
  ptr_y[i] = y + s_start[i] * d_out;
  ld_x[i] = k;
  ld_w[i] = n;
  ld_y[i] = n;
}

int64_t sgmv_tmp_size(int64_t num_problems) {
  constexpr auto sz = sizeof(void *) * 3 + sizeof(int64_t) * 3 +
                      sizeof(cutlass::gemm::GemmCoord);
  return sz * num_problems;
}

template <typename T>
inline T *alloc_from_buf(void **buf, int n) {
  auto *p = (T *)*buf;
  *buf = (void *)(p + n);
  return p;
}

template <typename DType>
bool sgmv(DType *y, DType *x, DType **w, int32_t *s_start, int32_t *s_end,
          void *tmp_d, int num_problems, int d_in, int d_out, int layer_idx,
          cudaStream_t stream) {
  using cutlass_t = typename cutlass_dtype<DType>::type;

  auto ptr_Y = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
  auto ptr_X = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
  auto ptr_W = alloc_from_buf<cutlass_t *>(&tmp_d, num_problems);
  auto ld_Y = alloc_from_buf<int64_t>(&tmp_d, num_problems);
  auto ld_X = alloc_from_buf<int64_t>(&tmp_d, num_problems);
  auto ld_W = alloc_from_buf<int64_t>(&tmp_d, num_problems);
  auto all_problems =
      alloc_from_buf<cutlass::gemm::GemmCoord>(&tmp_d, num_problems);

  precompute_sgmv_args<<<num_problems, 1, 0, stream>>>(
      all_problems, ptr_Y, ptr_X, ptr_W, ld_Y, ld_X, ld_W, (cutlass_t *)y,
      (cutlass_t *)x, (cutlass_t **)w, s_start, s_end, d_in, d_out, layer_idx);

  using cutlass::epilogue::thread::LinearCombination;
  using cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle;
  if (d_in < d_out) {
    // Expand
    using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
        cutlass_t,                                      // Element A
        cutlass::layout::RowMajor,                      // Layout A
        cutlass::ComplexTransform::kNone,               //
        8,                                              // Granularity A
        cutlass_t,                                      // Element B
        cutlass::layout::RowMajor,                      // Layout B
        cutlass::ComplexTransform::kNone,               //
        8,                                              // Granularity B
        cutlass_t,                                      // Element C&D
        cutlass::layout::RowMajor,                      // Layout C&D
        float,                                          // Element Accumulator
        cutlass::arch::OpClassTensorOp,                 // Operator Class Tag
        cutlass::arch::Sm80,                            // Architecture
        cutlass::gemm::GemmShape<32, 128, 16>,          // Thread Block Shape
        cutlass::gemm::GemmShape<32, 64, 16>,           // Warp Shape
        cutlass::gemm::GemmShape<16, 8, 8>,             // Instruction Shape
        LinearCombination<cutlass_t, 8, float, float>,  // Epilogue
        GemmIdentityThreadblockSwizzle<1>,              // Swizzling Operator
        2                                               // Stages
        >::GemmKernel;

    using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
    typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);

    using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
    typename GemmGrouped::Arguments args(all_problems, num_problems, 512,
                                         epilogue_op, ptr_X, ptr_W, ptr_Y,
                                         ptr_Y, ld_X, ld_W, ld_Y, ld_Y);

    GemmGrouped gemm;
    auto status = gemm.initialize(args, nullptr, stream);
    if (status != cutlass::Status::kSuccess) {
      fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n",
              cutlassGetStatusString(status));
      return false;
    }
    status = gemm.run(stream);
    if (status != cutlass::Status::kSuccess) {
      fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n",
              cutlassGetStatusString(status));
      return false;
    }
  } else {
    // Shrink
    using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped<
        cutlass_t,                                      // Element A
        cutlass::layout::RowMajor,                      // Layout A
        cutlass::ComplexTransform::kNone,               //
        8,                                              // Granularity A
        cutlass_t,                                      // Element B
        cutlass::layout::RowMajor,                      // Layout B
        cutlass::ComplexTransform::kNone,               //
        8,                                              // Granularity B
        cutlass_t,                                      // Element C&D
        cutlass::layout::RowMajor,                      // Layout C&D
        float,                                          // Element Accumulator
        cutlass::arch::OpClassTensorOp,                 // Operator Class Tag
        cutlass::arch::Sm80,                            // Architecture
        cutlass::gemm::GemmShape<16, 64, 64>,           // Thread Block Shape
        cutlass::gemm::GemmShape<16, 16, 64>,           // Warp Shape
        cutlass::gemm::GemmShape<16, 8, 16>,            // Instruction Shape
        LinearCombination<cutlass_t, 4, float, float>,  // Epilogue
        GemmIdentityThreadblockSwizzle<2>,              // Swizzling Operator
        2                                               // Stages
        >::GemmKernel;

    using EpilogueOutputOp = typename GemmKernel::Epilogue::OutputOp;
    typename EpilogueOutputOp::Params epilogue_op(1.0, 1.0);

    using GemmGrouped = cutlass::gemm::device::GemmGrouped<GemmKernel>;
    typename GemmGrouped::Arguments args(all_problems, num_problems, 512,
                                         epilogue_op, ptr_X, ptr_W, ptr_Y,
                                         ptr_Y, ld_X, ld_W, ld_Y, ld_Y);

    GemmGrouped gemm;
    auto status = gemm.initialize(args, nullptr, stream);
    if (status != cutlass::Status::kSuccess) {
      fprintf(stderr, "sgmv_cutlass gemm.initialize failed: %s\n",
              cutlassGetStatusString(status));
      return false;
    }
    status = gemm.run(stream);
    if (status != cutlass::Status::kSuccess) {
      fprintf(stderr, "sgmv_cutlass gemm.run failed: %s\n",
              cutlassGetStatusString(status));
      return false;
    }
  }
  return true;
}
sgmv_flashinfer/sgmv_all.cu
ADDED
@@ -0,0 +1,73 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>

#include "sgmv_config.h"
#include "sgmv_flashinfer.cuh"

template <typename T, uint32_t d_out>
bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp,
                 uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, cudaStream_t stream) {
  static_assert(d_out % 16 == 0);

  constexpr uint32_t num_warps = 4;
  constexpr uint32_t num_stages = 2;
  constexpr uint32_t num_k_frags_per_stage = 8;
  constexpr uint32_t num_blocks_n = d_out / 16;
  uint32_t smem = num_stages * sizeof(T) * num_k_frags_per_stage * 16 * 16 *
                  (num_warps + num_blocks_n);
  auto cooperative_kernel =
      flashinfer::sgmv::sgmv_shrink<true, T, int, num_warps, d_out>;
  auto kernel = flashinfer::sgmv::sgmv_shrink<false, T, int, num_warps, d_out>;

  int dev_id = 0;
  int num_blocks_per_sm = 0;
  int num_sm = 0;
  bool use_cooperative = true;
  cudaGetDevice(&dev_id);
  cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &num_blocks_per_sm, cooperative_kernel, num_warps * 32, smem);

  const uint32_t max_grid_size = num_sm * num_blocks_per_sm;

  uint32_t chunk_size = 256;
  uint32_t num_chunks = (d_in + chunk_size - 1) / chunk_size;
  if (num_chunks * num_problems > max_grid_size) {
    use_cooperative = false;
    chunk_size = d_in;
    num_chunks = 1;
  }

  dim3 nthrs(32, num_warps);
  dim3 nblks(num_chunks, num_problems);

  void* args[] = {(void*)&y, (void*)&x, (void*)&w,
                  (void*)&s_start, (void*)&s_end, (void*)&tmp, (void*)&num_problems,
                  (void*)&d_in, (void*)&layer_idx, (void*)&chunk_size};

  cudaError_t status;
  if (use_cooperative) {
    if (smem > 46 * 1024) {
      cudaFuncSetAttribute(cooperative_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
    }
    status = cudaLaunchCooperativeKernel((void*)cooperative_kernel, nblks,
                                         nthrs, args, smem, stream);
  } else {
    if (smem > 46 * 1024) {
      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem);
    }
    status = cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem, stream);
  }
  return status == cudaSuccess;
}

#define INST(T, d_out)                                                                           \
  template bool sgmv_shrink<T, d_out>(T * y, T * x, T * *w, int32_t * s_start, int32_t * s_end, \
                                      void* tmp, uint32_t num_problems,                         \
                                      uint32_t d_in, uint32_t layer_idx, cudaStream_t stream);

FOR_SGMV_NARROW(INST, nv_half);
FOR_SGMV_NARROW(INST, nv_bfloat16);
sgmv_flashinfer/sgmv_config.h
ADDED
@@ -0,0 +1,17 @@
#pragma once
#include <cstdint>

template <typename T, uint32_t d_out>
bool sgmv_shrink(T* y, T* x, T** w, int32_t* s_start, int32_t* s_end, void* tmp,
                 uint32_t num_problems, uint32_t d_in, uint32_t layer_idx, cudaStream_t stream);

// clang-format off

#define FOR_SGMV_NARROW(f, T) \
    f(T, 16) \
    f(T, 32) \
    f(T, 64) \
    f(T, 96) \
    f(T, 128)

// clang-format on
sgmv_flashinfer/sgmv_flashinfer.cuh
ADDED
@@ -0,0 +1,356 @@
#pragma once
#include <cooperative_groups.h>

#include "flashinfer/cp_async.cuh"
#include "flashinfer/mma.cuh"
#include "flashinfer/permuted_smem.cuh"
#include "flashinfer/vec_dtypes.cuh"

namespace flashinfer {

namespace sgmv {

template <bool cooperative, typename T, typename IdType, uint32_t num_warps,
          uint32_t d_out>
__global__ void sgmv_shrink(T* y, T* x, T** w, IdType* s_starts, IdType* s_ends, float* tmp,
                            uint32_t num_problems, uint32_t d_in,
                            uint32_t layer_idx, uint32_t chunk_size) {
  auto block = cooperative_groups::this_thread_block();
  auto grid = cooperative_groups::this_grid();
  constexpr auto fill_mode = cp_async::SharedMemFillMode::kFillZero;
  const uint32_t problem_id = blockIdx.y;
  const uint32_t bx = blockIdx.x;

  constexpr uint32_t num_stages = 2;
  constexpr uint32_t num_k_frags = 8;
  constexpr uint32_t num_cells_k = (num_k_frags * 16) / cell_capacity<T>();
  constexpr uint32_t num_blocks_n = d_out / 16;
  const uint32_t num_chunks = gridDim.x;
  const uint32_t chunk_start = chunk_size * bx;
  const uint32_t num_iterations =
      (chunk_size + (num_k_frags * 16 - 1)) / (num_k_frags * 16);
  constexpr uint32_t num_cells_n =
      (d_out < 32 ? 32 : d_out) / cell_capacity<T>();
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;

  extern __shared__ uint8_t smem[];

  smem_t x_smem[2]{smem, smem + sizeof(T) * num_warps * 16 * 16 * num_k_frags};
  smem_t w_smem[2]{smem + sizeof(T) * 2 * num_warps * 16 * 16 * num_k_frags,
                   smem + sizeof(T) * 16 * 16 * num_k_frags *
                              (2 * num_warps + num_blocks_n)};
  smem_t y_smem(smem);

  uint32_t x_frag[num_k_frags][4];
  uint32_t w_frag[num_k_frags][num_blocks_n][4];
  float y_frag[num_blocks_n][8];

  const uint32_t s_start = s_starts[problem_id], s_end = s_ends[problem_id];
  const uint32_t num_steps = (s_start < s_end) ? (s_end - s_start + (num_warps * 16 - 1)) / (num_warps * 16) : 0;
  for (uint32_t i = 0; i < num_steps; ++i) {
    // init y_frag
    if (bx == 0) {
      if constexpr (num_blocks_n == 1) {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 2;
        T* y_ptr = y + row_idx * d_out + (tx % 2) * cell_capacity<T>();
        auto offset =
            smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 2, tx % 2);
        y_smem.load_128b_async<fill_mode>(offset, y_ptr, row_idx < s_end);
      } else {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
        T* y_ptr = y + row_idx * d_out + (tx % 4) * cell_capacity<T>();
        auto offset =
            smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 4, tx % 4);
#pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
#pragma unroll
          for (uint32_t fno = 0; fno < num_blocks_n / 2; ++fno) {
            y_smem.load_128b_async<fill_mode>(offset, y_ptr, row_idx < s_end);
            y_ptr += 4 * cell_capacity<T>();
            offset += 8;
          }
          row_idx += 8;
          y_ptr += 8 * d_out - 2 * num_blocks_n * cell_capacity<T>();
          offset += 8 * num_cells_n - 4 * num_blocks_n;
        }
      }
      cp_async::commit_group();
      cp_async::wait_group<0>();
      block.sync();

      auto offset =
          smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx % 16, tx / 16);
#pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
        uint32_t tmp[4];
        y_smem.ldmatrix_m8n8x4(offset, tmp);
        vec_cast<float, T, 8>(y_frag[fn], (T*)tmp);
        offset = (offset ^ 0x2) + (fn & 0x1) * 8;
      }
    } else {
#pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
#pragma unroll
        for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
          y_frag[fn][reg_id] = 0.f;
        }
      }
    }

    // preload x_smem, w_smem
#pragma unroll
    for (uint32_t iter = 0; iter < num_stages; ++iter) {
      uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
      T* x_ptr = x + row_idx * d_in + chunk_start +
                 (2 * num_k_frags * iter + tx % 4) * cell_capacity<T>();
      T* x_ptr_max = x + row_idx * d_in + min(d_in, chunk_start + chunk_size);
      auto offset =
          smem_t::get_permuted_offset<num_cells_k>(ty * 16 + tx / 4, tx % 4);
      // pre-load x_smem, w_smem
#pragma unroll
      for (uint32_t j = 0; j < 2; ++j) {
#pragma unroll
        for (uint32_t fko = 0; fko < num_k_frags / 2; ++fko) {
          x_smem[iter].load_128b_async<fill_mode>(
              offset, x_ptr, row_idx < s_end && x_ptr < x_ptr_max);
          x_ptr += 4 * cell_capacity<T>();
          offset += 8;
        }
        row_idx += 8;
        x_ptr += 8 * d_in - 2 * cell_capacity<T>() * num_k_frags;
        x_ptr_max += 8 * d_in;
        offset += 8 * num_cells_k - 4 * num_k_frags;
      }
      row_idx -= 8;

      static_assert(num_k_frags % (num_warps * 2) == 0);
      constexpr uint32_t num_fko_iters_per_warp = num_k_frags / (num_warps * 2);
#pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
        T* w_ptr = w[problem_id] + layer_idx * d_in * d_out +
                   (fn * 16 + tx / 4) * d_in + chunk_start +
                   (2 * num_k_frags * iter + ty * num_fko_iters_per_warp * 4 +
                    tx % 4) *
                       cell_capacity<T>();
        T* w_ptr_max =
            w[problem_id] + layer_idx * d_in * d_out +
            min((fn * 16 + tx / 4 + 1) * d_in,
                (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size);
        auto offset = smem_t::get_permuted_offset<num_cells_k>(
            fn * 16 + tx / 4, ty * num_fko_iters_per_warp * 4 + tx % 4);
#pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
#pragma unroll
          for (uint32_t fko = 0; fko < num_fko_iters_per_warp; ++fko) {
            w_smem[iter].load_128b_async<fill_mode>(offset, w_ptr,
                                                    w_ptr < w_ptr_max);
            w_ptr += 4 * cell_capacity<T>();
            offset += 8;
          }
          w_ptr += 8 * d_in - 4 * cell_capacity<T>() * num_fko_iters_per_warp;
          w_ptr_max += 8 * d_in;
          offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp;
        }
      }
      cp_async::commit_group();
    }

#pragma unroll 1
    for (uint32_t iter = 0; iter < num_iterations; ++iter) {
      const uint32_t stage_idx = iter % 2;
      cp_async::wait_group<1>();
      block.sync();

      auto offset =
          smem_t::get_permuted_offset<num_cells_k>(ty * 16 + tx % 16, tx / 16);
#pragma unroll
      for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
        x_smem[stage_idx].ldmatrix_m8n8x4(offset, x_frag[fk]);
        offset = (offset ^ 0x2) + (fk & 0x1) * 8;
      }

#pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
        auto offset = smem_t::get_permuted_offset<num_cells_k>(
            fn * 16 + 8 * (tx / 16) + tx % 8, (tx % 16) / 8);
#pragma unroll
        for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
          w_smem[stage_idx].ldmatrix_m8n8x4(offset, w_frag[fk][fn]);
          offset = (offset ^ 0x2) + (fk & 0x1) * 8;
        }
        offset += 16 * num_cells_k - 4 * num_k_frags;
      }

      // compute y_frag
#pragma unroll
      for (uint32_t fk = 0; fk < num_k_frags; ++fk) {
#pragma unroll
        for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
          mma::mma_sync_m16n16k16_row_col_f16f16f32<T>(y_frag[fn], x_frag[fk],
                                                       w_frag[fk][fn]);
        }
      }
      block.sync();

      // load next stage
      if (iter + num_stages < num_iterations) {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
        T* x_ptr = x + row_idx * d_in + chunk_start +
                   (2 * num_k_frags * (iter + num_stages) + tx % 4) *
                       cell_capacity<T>();
        T* x_ptr_max = x + row_idx * d_in + min(d_in, chunk_start + chunk_size);
        auto offset =
            smem_t::get_permuted_offset<num_cells_k>(ty * 16 + tx / 4, tx % 4);
        // pre-load x_smem, w_smem
#pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
#pragma unroll
          for (uint32_t fko = 0; fko < num_k_frags / 2; ++fko) {
            x_smem[stage_idx].load_128b_async<fill_mode>(
                offset, x_ptr, row_idx < s_end && x_ptr < x_ptr_max);
            x_ptr += 4 * cell_capacity<T>();
            offset += 8;
          }
          row_idx += 8;
          x_ptr += 8 * d_in - 2 * cell_capacity<T>() * num_k_frags;
          x_ptr_max += 8 * d_in;
          offset += 8 * num_cells_k - 4 * num_k_frags;
        }
        row_idx -= 8;

        constexpr uint32_t num_fko_iters_per_warp =
            num_k_frags / (num_warps * 2);
#pragma unroll
        for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
          T* w_ptr = w[problem_id] + layer_idx * d_in * d_out +
                     (fn * 16 + tx / 4) * d_in + chunk_start +
                     (2 * num_k_frags * (iter + num_stages) +
                      ty * num_fko_iters_per_warp * 4 + tx % 4) *
                         cell_capacity<T>();
          T* w_ptr_max =
              w[problem_id] + layer_idx * d_in * d_out +
              min((fn * 16 + tx / 4 + 1) * d_in,
                  (fn * 16 + tx / 4) * d_in + chunk_start + chunk_size);
          auto offset = smem_t::get_permuted_offset<num_cells_k>(
              fn * 16 + tx / 4, ty * num_fko_iters_per_warp * 4 + tx % 4);
#pragma unroll
          for (uint32_t j = 0; j < 2; ++j) {
#pragma unroll
            for (uint32_t fko = 0; fko < num_fko_iters_per_warp; ++fko) {
              w_smem[stage_idx].load_128b_async<fill_mode>(offset, w_ptr,
                                                           w_ptr < w_ptr_max);
              w_ptr += 4 * cell_capacity<T>();
              offset += 8;
            }
            w_ptr += 8 * d_in - 4 * cell_capacity<T>() * num_fko_iters_per_warp;
            w_ptr_max += 8 * d_in;
            offset += 8 * num_cells_k - 8 * num_fko_iters_per_warp;
          }
        }
      }
      cp_async::commit_group();
    }
    cp_async::wait_group<0>();
    block.sync();

    if constexpr (cooperative) {
#pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
        vec_t<float, 8>::memcpy(
            tmp + (fn * grid.size() +
                   (problem_id * num_chunks + bx) * block.num_threads() +
                   block.thread_rank()) *
                      8,
            y_frag[fn]);
      }
      grid.sync();

#pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
#pragma unroll
        for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
          y_frag[fn][reg_id] = 0.f;
        }
        for (uint32_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
          vec_t<float, 8> y_other;
          y_other.load(tmp + (fn * grid.size() +
                              (problem_id * num_chunks + chunk_idx) *
                                  block.num_threads() +
                              block.thread_rank()) *
                                 8);
#pragma unroll
          for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
            y_frag[fn][reg_id] += y_other[reg_id];
          }
        }
      }
    }

    if (bx == 0) {
      // store y_frag
      auto offset =
          smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 4, 0);
#pragma unroll
      for (uint32_t fn = 0; fn < num_blocks_n; ++fn) {
        vec_cast<T, float, 2>((T*)(y_smem.base + offset) + (tx % 4) * 2,
                              &y_frag[fn][0]);
        vec_cast<T, float, 2>(
            (T*)(y_smem.base + offset + 8 * num_cells_n) + (tx % 4) * 2,
            &y_frag[fn][2]);
        vec_cast<T, float, 2>((T*)(y_smem.base + (offset ^ 0x1)) + (tx % 4) * 2,
                              &y_frag[fn][4]);
        vec_cast<T, float, 2>(
            (T*)(y_smem.base + (offset ^ 0x1) + 8 * num_cells_n) + (tx % 4) * 2,
            &y_frag[fn][6]);
        offset = (offset ^ 0x2) + (fn & 0x1) * 8;
      }

      // store y
      if constexpr (num_blocks_n == 1) {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 2;
        T* y_ptr = y + row_idx * d_out + (tx % 2) * cell_capacity<T>();
        auto offset =
            smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 2, tx % 2);
        if (row_idx < s_end) {
          y_smem.store_128b(offset, y_ptr);
        }
      } else {
        uint32_t row_idx = s_start + (i * num_warps + ty) * 16 + tx / 4;
        T* y_ptr = y + row_idx * d_out + (tx % 4) * cell_capacity<T>();
        auto offset =
            smem_t::get_permuted_offset<num_cells_n>(ty * 16 + tx / 4, tx % 4);
#pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
#pragma unroll
          for (uint32_t fno = 0; fno < num_blocks_n / 2; ++fno) {
            if (row_idx < s_end) {
              y_smem.store_128b(offset, y_ptr);
            }
            y_ptr += 4 * cell_capacity<T>();
            offset += 8;
          }
          row_idx += 8;
          y_ptr += 8 * d_out - 2 * num_blocks_n * cell_capacity<T>();
          offset += 8 * num_cells_n - 4 * num_blocks_n;
        }
      }
    }
  }

  // handle the case where one of the segments needs more steps than this one
  // to avoid deadlock
  if constexpr (cooperative) {
    uint32_t max_segment_size = 0;
    for (uint32_t i = 0; i < num_problems; ++i) {
      max_segment_size = max(max_segment_size, s_ends[i] - s_starts[i]);
    }

    const uint32_t max_steps = (max_segment_size + (num_warps * 16 - 1)) / (num_warps * 16);
    for (uint32_t i = 0; i < max_steps - num_steps; ++i) {
      grid.sync();
    }
  }
}

}  // namespace sgmv
}  // namespace flashinfer
tests/test_sgmv.py
ADDED
@@ -0,0 +1,125 @@
from typing import List, Tuple

import pytest
import torch

from punica_sgmv import (
    get_tmp_tensors,
    lora_a_sgmv_cutlass,
    lora_b_sgmv_cutlass,
    pad_rank,
    use_cutlass_shrink,
)


def lora_ref_impl(
    y: torch.Tensor,
    x: torch.Tensor,
    wa: List[torch.Tensor],
    wb: List[torch.Tensor],
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    for i in range(len(wa)):
        if s_end[i] - s_start[i] <= 0:
            continue

        xi = x[s_start[i]:s_end[i]]
        wai = wa[i][layer_idx, :, :]
        wbi = wb[i][layer_idx, :, :]

        if not use_cutlass_shrink(lora_rank):
            wai = wai.t()

        yi = y[s_start[i]:s_end[i]]
        tmp = (xi @ wai)
        y[s_start[i]:s_end[i]] = (yi + tmp @ wbi)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("segments", [
    ([0, 2], [1, 3]),
    ([0, -1], [1, -1]),
])
@pytest.mark.parametrize("lora_rank", [8, 16, 32, 64, 128])
def test_add_lora_sgmv(lora_rank: int, segments: Tuple[List[int], List[int]]):
    torch.manual_seed(42)

    B = 3
    H = 1024
    r = lora_rank
    nlayers = 2

    device = torch.device("cuda:0")

    y = torch.zeros((B, H), dtype=torch.float16, device=device)
    x = torch.randn((B, H), dtype=torch.float16, device=device)
    wa = torch.randn(nlayers, r, H, dtype=torch.float16, device=device)
    if use_cutlass_shrink(r):
        # cutlass uses (H, r) layout
        wa = wa.transpose(1, 2).contiguous()

    # TODO(travis): transpose (r, H) -> (H, r) when not using cutlass
    wb = torch.randn(nlayers, r, H, dtype=torch.float16, device=device)

    s1, s2 = segments
    s_start = torch.tensor(s1, dtype=torch.int32, device=device)
    s_end = torch.tensor(s2, dtype=torch.int32, device=device)

    wa_list = [wa if y - x > 0 else None for x, y in zip(s1, s2)]
    wb_list = [wb if y - x > 0 else None for x, y in zip(s1, s2)]

    wa_ptr = torch.tensor([wa.data_ptr() if wa is not None else 0 for wa in wa_list], dtype=torch.int64, device=device)
    wb_ptr = torch.tensor([wb.data_ptr() if wb is not None else 0 for wb in wb_list], dtype=torch.int64, device=device)

    layer_idx = 0

    y_ref = y.clone()
    lora_ref_impl(y_ref, x, wa_list, wb_list, s_start, s_end, layer_idx, r)

    tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), r, x.device)
    y_ours = torch.zeros((B, H), dtype=torch.float16, device=device)

    v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r)
    lora_b_sgmv_cutlass(y_ours, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)

    assert torch.allclose(y_ref, y_ours, rtol=1e-2, atol=1e-3)

    # graph trace
    tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), r, x.device)
    y_ours_graph = torch.zeros((B, H), dtype=torch.float16, device=device)

    torch.cuda.synchronize(device)
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=None):
        v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r)
        lora_b_sgmv_cutlass(y_ours_graph, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)

    torch.cuda.synchronize(device)
    graph.replay()

    assert torch.allclose(y_ours, y_ours_graph, rtol=1e-2, atol=1e-3)


@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
@pytest.mark.parametrize("lora_rank", [8, 16, 32, 64, 128])
def test_pad_rank(lora_rank: int, world_size: int):
    bs = 8
    h = 1024
    x = torch.randn((bs, h), dtype=torch.float16)

    lora_a = torch.randn((h, lora_rank), dtype=torch.float16)
    lora_b = torch.randn((lora_rank, h), dtype=torch.float16)

    lora_a_padded = pad_rank(lora_a, dim=1, world_size=world_size)
    lora_b_padded = pad_rank(lora_b, dim=0, world_size=world_size)

    assert lora_a_padded.size(1) == lora_b_padded.size(0)
    assert lora_a_padded.size(1) >= lora_a.size(1)
    assert lora_b_padded.size(0) >= lora_b.size(0)

    expected = x @ lora_a @ lora_b
    actual = x @ lora_a_padded @ lora_b_padded
    assert torch.allclose(expected, actual)
torch-ext/punica_sgmv/__init__.py
ADDED
@@ -0,0 +1,172 @@
from typing import Optional, Tuple
from functools import lru_cache

import torch
import torch.nn.functional as F

from ._ops import ops

MIN_SGMV_RANK = 8
MIN_RANK_CUSTOM = 16
MAX_RANK_CUSTOM = 128
SGMV_BLOCK_SIZE = 16
BGMV_MAX_RANK = 128

def orient_for_rank(t: torch.Tensor, rank: int) -> torch.Tensor:
    if MIN_RANK_CUSTOM <= rank <= MAX_RANK_CUSTOM:
        return t.transpose(0, 1)
    return t

def add_lora_sgmv_cutlass(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.Tensor,
    s_end: torch.Tensor,
    layer_idx: int,
    lora_rank: int,
):
    """
    Semantics:
      y[s[i]:s[i+1]] += x[s[i]:s[i+1]] @ deref(wa_ptr[i]).T @ deref(wb_ptr[i])

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
        Weight matrix shape: `[num_layers, R, H1]`.
      wb_ptr: Shape: `[S]`. DType: torch.int64. Pointer to the weight matrices.\
        Weight matrix shape: `[num_layers, R, H2]`.
      s_start: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices start indices.
      s_end: Shape: `[S]`, DType: torch.int32. Indptr of the weight matrices end indices.
      layer_idx: Layer index of the weight matrices.
    """
    if lora_rank < MIN_RANK_CUSTOM or lora_rank > MAX_RANK_CUSTOM:
        # Custom SGMV shrink only supports rank 16, 32, 64, 128
        _add_lora_sgmv_cutlass_legacy(y, x, wa_ptr, wb_ptr, s_start, s_end, layer_idx, lora_rank)
        return

    tmp1 = torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=x.device)
    tmp2_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp2 = torch.empty((tmp2_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp1, layer_idx)
    ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp2, layer_idx)

def _add_lora_sgmv_cutlass_legacy(
    y: torch.Tensor,
    x: torch.Tensor,
    wa_ptr: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
):
    tmp_size = ops.sgmv_cutlass_tmp_size(wa_ptr.size(0))
    tmp = torch.empty((tmp_size,), dtype=torch.uint8, device=x.device)
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)

def lora_a_sgmv_cutlass(
    x: torch.Tensor,
    tmp: torch.Tensor,
    wa_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
    lora_rank: int,
) -> torch.Tensor:
    v = torch.zeros((x.size(0), lora_rank), dtype=x.dtype, device=x.device)
    if MIN_RANK_CUSTOM <= lora_rank <= MAX_RANK_CUSTOM:
        ops.sgmv_shrink(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    else:
        ops.sgmv_cutlass(v, x, wa_ptr, s_start, s_end, tmp, layer_idx)
    return v


def lora_b_sgmv_cutlass(
    y: torch.Tensor,
    v: torch.Tensor,
    tmp: torch.Tensor,
    wb_ptr: torch.Tensor,
    s_start: torch.IntTensor,
    s_end: torch.IntTensor,
    layer_idx: int,
):
    ops.sgmv_cutlass(y, v, wb_ptr, s_start, s_end, tmp, layer_idx)

def add_lora_a_bgmv(
    v: torch.Tensor,
    x: torch.Tensor,
    wa_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    ops.dispatch_bgmv(v, x, wa_T_all, indicies, layer_idx, 1.0)


def add_lora_b_bgmv(
    y: torch.Tensor,
    v: torch.Tensor,
    wb_T_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
):
    ops.dispatch_bgmv(y, v, wb_T_all, indicies, layer_idx, 1.0)


def pad_rank(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    """Pad a tensor to the minimum rank for SGMV and the nearest multiple of the SGMV block size."""
    # tensor parallelism will result in effective rank being divided by world_size,
    # so we need to scale the min rank to offset that effect
    min_rank = MIN_SGMV_RANK * world_size
    return pad_to_min_rank(t, dim, min_rank)

def pad_to_min_rank(t: torch.Tensor, dim: int, min_rank: int) -> torch.Tensor:
    # if we're at or below the min rank, pad up to the min rank
    # otherwise, pad to the nearest multiple of the block size
    current_rank = t.size(dim)
    target_rank = (
        min_rank
        if current_rank <= min_rank
        else (current_rank + SGMV_BLOCK_SIZE - 1) // SGMV_BLOCK_SIZE * SGMV_BLOCK_SIZE
    )
    if current_rank == target_rank:
        return t

    pad_size = target_rank - current_rank

    # see complicated pad syntax here: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
    pad = [0, 0] * t.dim()
    pad[(t.dim() - dim - 1) * 2 + 1] = pad_size
    pad = tuple(pad)

    return F.pad(t, pad, mode="constant", value=0.0)

def use_cutlass_shrink(lora_rank: int) -> bool:
    return lora_rank < MIN_RANK_CUSTOM

@lru_cache(maxsize=1)
def get_tmp_tensor(device: torch.device) -> torch.Tensor:
    return torch.empty((8 * 1024 * 1024,), dtype=torch.uint8, device=device)

@lru_cache(maxsize=32)
def get_tmp_tensor_for_size(size: int, device: torch.device) -> torch.Tensor:
    tmp_size = ops.sgmv_cutlass_tmp_size(size)
    return torch.empty((tmp_size,), dtype=torch.uint8, device=device)

def get_tmp_expand_size(size: int) -> int:
    return ops.sgmv_cutlass_tmp_size(size)


def get_tmp_tensors(nsegments: int, lora_rank: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    if use_cutlass_shrink(lora_rank):
        tmp = get_tmp_tensor_for_size(nsegments, device)
        return tmp, tmp
    else:
        tmp_shrink = get_tmp_tensor(device)
        tmp_expand = get_tmp_tensor_for_size(nsegments, device)
        return tmp_shrink, tmp_expand
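
For orientation, a minimal end-to-end usage sketch of this module, adapted from tests/test_sgmv.py. The shapes, rank, and single-segment setup below are illustrative only, not part of the API:

    # Usage sketch based on tests/test_sgmv.py; requires CUDA and the built extension.
    import torch
    from punica_sgmv import get_tmp_tensors, lora_a_sgmv_cutlass, lora_b_sgmv_cutlass, use_cutlass_shrink

    device = torch.device("cuda:0")
    B, H, r, nlayers, layer_idx = 3, 1024, 16, 2, 0

    x = torch.randn((B, H), dtype=torch.float16, device=device)
    y = torch.zeros((B, H), dtype=torch.float16, device=device)
    wa = torch.randn(nlayers, r, H, dtype=torch.float16, device=device)
    if use_cutlass_shrink(r):
        wa = wa.transpose(1, 2).contiguous()  # cutlass shrink expects (H, r) layout
    wb = torch.randn(nlayers, r, H, dtype=torch.float16, device=device)

    # One weight pointer and one [start, end) row segment per adapter.
    wa_ptr = torch.tensor([wa.data_ptr()], dtype=torch.int64, device=device)
    wb_ptr = torch.tensor([wb.data_ptr()], dtype=torch.int64, device=device)
    s_start = torch.tensor([0], dtype=torch.int32, device=device)
    s_end = torch.tensor([B], dtype=torch.int32, device=device)

    tmp_shrink, tmp_expand = get_tmp_tensors(wa_ptr.size(0), r, device)
    v = lora_a_sgmv_cutlass(x, tmp_shrink, wa_ptr, s_start, s_end, layer_idx, r)  # shrink to rank r
    lora_b_sgmv_cutlass(y, v, tmp_expand, wb_ptr, s_start, s_end, layer_idx)      # expand back, y += v @ wb
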
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,23 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("sgmv_shrink(Tensor! y, Tensor x, Tensor w_ptr, Tensor s_start, "
          "Tensor s_end, Tensor! tmp, int layer_idx) -> ()");
  ops.impl("sgmv_shrink", torch::kCUDA, &dispatch_sgmv_shrink);

  ops.def("sgmv_cutlass(Tensor! y, Tensor x, Tensor w_ptr, Tensor s_start, "
          "Tensor s_end, Tensor! tmp, int layer_idx) -> ()");
  ops.impl("sgmv_cutlass", torch::kCUDA, &dispatch_sgmv_cutlass);

  ops.def("sgmv_cutlass_tmp_size(int num_problems) -> int");
  ops.impl("sgmv_cutlass_tmp_size", &sgmv_tmp_size);

  ops.def("dispatch_bgmv(Tensor! y, Tensor x, Tensor w_ptr, Tensor indices, "
          "int layer_indices, float scale) -> ()");
  ops.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,16 @@
#pragma once

#include <torch/torch.h>

void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
                   torch::Tensor indicies, int64_t layer_idx, double scale);

void dispatch_sgmv_cutlass(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
                           torch::Tensor s_start, torch::Tensor s_end,
                           torch::Tensor tmp, int64_t layer_idx);

void dispatch_sgmv_shrink(torch::Tensor y, torch::Tensor x, torch::Tensor w_ptr,
                          torch::Tensor s_start, torch::Tensor s_end, torch::Tensor tmp, int64_t layer_idx);

int64_t sgmv_tmp_size(int64_t num_problems);