Upload custom kernels
Browse files

- build.toml +20 -0
- rmsnorm_kernel/rmsnorm.cpp +47 -0
- torch-ext/rmsnorm_kernel/__init__.py +0 -0
- torch-ext/torch_bindings.cpp +11 -0
- torch-ext/torch_bindings.h +5 -0
build.toml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[general]
name = "rmsnorm_kernel"

[torch]
# FIX: this commit adds torch-ext/torch_bindings.cpp and
# torch-ext/torch_bindings.h (plural "bindings"); the original entries
# pointed at "torch_binding.*", which do not exist, so the build could
# never find the binding sources.
src = [
    "torch-ext/torch_bindings.cpp",
    "torch-ext/torch_bindings.h"
]

# Renamed from [kernel.activation] to match the kernel this repo actually
# ships; nothing else references the old section name (the Python package
# __init__.py is empty).
[kernel.rmsnorm]
src = [
    "rmsnorm_kernel/rmsnorm.cpp",
]

depends = [ "torch" ]

# NOTE(review): rmsnorm.cpp contains __device__ / thrust code; confirm the
# builder treats it as CUDA (e.g. a `backend = "cuda"` key or a `.cu`
# extension may be required) — a plain .cpp compiled by a host compiler
# will not build.

# If the kernel is only supported on specific capabilities, set the
# cuda-capabilities option:
#
# cuda-capabilities = [ "9.0", "10.0", "12.0" ]
|
rmsnorm_kernel/rmsnorm.cpp
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <cmath>

#include <thrust/execution_policy.h>  // thrust::device (was missing; used in rmsnorm_forward)
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

#include <torch/extension.h>
|
5 |
+
|
6 |
+
const float EPS = 1e-5f;
|
7 |
+
|
8 |
+
struct RmsnormFunctor {
|
9 |
+
const float* x;
|
10 |
+
const float* gamma;
|
11 |
+
float* out;
|
12 |
+
int hidden_dim;
|
13 |
+
|
14 |
+
RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
|
15 |
+
: x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}
|
16 |
+
|
17 |
+
__device__
|
18 |
+
void operator()(int row_idx) {
|
19 |
+
const float* row_x = x + row_idx * hidden_dim;
|
20 |
+
float* row_out = out + row_idx * hidden_dim;
|
21 |
+
|
22 |
+
float sum_sq = 0.0f;
|
23 |
+
for (int i = 0; i < hidden_dim; ++i)
|
24 |
+
sum_sq += row_x[i] * row_x[i];
|
25 |
+
|
26 |
+
float rms = sqrtf(sum_sq / hidden_dim + EPS);
|
27 |
+
|
28 |
+
for (int i = 0; i < hidden_dim; ++i)
|
29 |
+
row_out[i] = (row_x[i] / rms) * gamma[i];
|
30 |
+
}
|
31 |
+
};
|
32 |
+
|
33 |
+
void rmsnorm_forward(torch::Tensor x, torch::Tensor gamma, torch::Tensor out) {
|
34 |
+
int B = x.size(0), S = x.size(1), H = x.size(2);
|
35 |
+
int rows = B * S;
|
36 |
+
|
37 |
+
const float* x_ptr = x.data_ptr<float>();
|
38 |
+
const float* gamma_ptr = gamma.data_ptr<float>();
|
39 |
+
float* out_ptr = out.data_ptr<float>();
|
40 |
+
|
41 |
+
thrust::counting_iterator<int> iter(0);
|
42 |
+
thrust::for_each(
|
43 |
+
thrust::device,
|
44 |
+
iter, iter + rows,
|
45 |
+
RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H)
|
46 |
+
);
|
47 |
+
}
|
torch-ext/rmsnorm_kernel/__init__.py
ADDED
File without changes
|
torch-ext/torch_bindings.cpp
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/library.h>

#include "registration.h"
#include "torch_bindings.h"

// Registers the custom op with the PyTorch dispatcher under this
// extension's namespace (TORCH_EXTENSION_NAME is supplied by the build).
// Schema semantics: `out` is mutated in place (Tensor!), `input` and
// `gamma` are read-only, and the op returns nothing. The implementation
// is registered for the CUDA dispatch key only.
// NOTE(review): the source paths listed under [torch] in build.toml must
// name this file exactly, or the binding is never compiled — verify.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("rmsnorm_forward(Tensor! out, Tensor input, Tensor gamma) -> ()");
  ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_bindings.h
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#pragma once

#include <torch/torch.h>

// In-place RMSNorm forward: normalizes each row of `input` over its last
// dimension (RMS with an epsilon inside the sqrt) and scales elementwise
// by `gamma`, writing the result into the caller-allocated `out`.
// CUDA implementation lives in rmsnorm_kernel/rmsnorm.cpp; registered
// with the dispatcher in torch_bindings.cpp.
void rmsnorm_forward(torch::Tensor &out, torch::Tensor const &input, torch::Tensor const &gamma);
|