medmekk (HF Staff) committed
Commit ac27409 · verified · 1 Parent(s): 3cfa33d

Upload custom kernels
build.toml ADDED
@@ -0,0 +1,20 @@
+[general]
+name = "rmsnorm_kernel"
+
+[torch]
+src = [
+  "torch-ext/torch_bindings.cpp",
+  "torch-ext/torch_bindings.h"
+]
+
+[kernel.activation]
+src = [
+  "rmsnorm_kernel/rmsnorm.cpp",
+]
+
+depends = ["torch"]
+
+# If the kernel is only supported on specific capabilities, set the
+# cuda-capabilities option:
+#
+# cuda-capabilities = [ "9.0", "10.0", "12.0" ]
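For context, once this project is built with kernel-builder and pushed to the Hub, it can be loaded with the `kernels` Python package. A minimal usage sketch, assuming a hypothetical repository id `medmekk/rmsnorm_kernel` and that the Python wrapper sketched further below exposes `rmsnorm_forward` (neither is confirmed by this commit):

import torch
from kernels import get_kernel  # Hugging Face `kernels` package

# Hypothetical repo id; replace with the actual Hub repository.
rmsnorm_kernel = get_kernel("medmekk/rmsnorm_kernel")

x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float32)
gamma = torch.ones(64, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)

# Out-parameter convention: the op writes into `out` and returns nothing.
rmsnorm_kernel.rmsnorm_forward(out, x, gamma)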
rmsnorm_kernel/rmsnorm.cpp ADDED
@@ -0,0 +1,55 @@
+#include <torch/extension.h>
+#include <thrust/execution_policy.h>
+#include <thrust/for_each.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <cmath>
+
+constexpr float EPS = 1e-5f;
+
+// Each functor call normalizes one row of `hidden_dim` elements:
+// out[i] = x[i] / sqrt(mean(x^2) + EPS) * gamma[i]
+struct RmsnormFunctor {
+  const float* x;
+  const float* gamma;
+  float* out;
+  int hidden_dim;
+
+  RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
+      : x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}
+
+  __device__
+  void operator()(int row_idx) {
+    const float* row_x = x + row_idx * hidden_dim;
+    float* row_out = out + row_idx * hidden_dim;
+
+    // Sum of squares over the hidden dimension.
+    float sum_sq = 0.0f;
+    for (int i = 0; i < hidden_dim; ++i)
+      sum_sq += row_x[i] * row_x[i];
+
+    float rms = sqrtf(sum_sq / hidden_dim + EPS);
+
+    // Normalize and apply the learned per-feature scale.
+    for (int i = 0; i < hidden_dim; ++i)
+      row_out[i] = (row_x[i] / rms) * gamma[i];
+  }
+};
+
+// Matches the declaration in torch-ext/torch_bindings.h: the output tensor
+// comes first and is written in place; inputs are passed by const reference.
+void rmsnorm_forward(torch::Tensor &out, torch::Tensor const &x, torch::Tensor const &gamma) {
+  int B = x.size(0), S = x.size(1), H = x.size(2);
+  int rows = B * S;
+
+  const float* x_ptr = x.data_ptr<float>();
+  const float* gamma_ptr = gamma.data_ptr<float>();
+  float* out_ptr = out.data_ptr<float>();
+
+  // One functor invocation per (batch, sequence) row, executed on the device.
+  thrust::counting_iterator<int> iter(0);
+  thrust::for_each(
+      thrust::device,
+      iter, iter + rows,
+      RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H)
+  );
+}
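A plain PyTorch reference that mirrors the functor's math (mean of squares over the hidden dimension, eps added under the square root, then a per-feature scale by gamma) is handy for checking the kernel numerically. A minimal sketch; the function name is illustrative only:

import torch

def rmsnorm_reference(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x: (B, S, H) float32, gamma: (H,) float32 -- the layout the kernel assumes.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x / rms) * gamma

# Compare against the custom op on random inputs.
x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float32)
gamma = torch.randn(64, device="cuda", dtype=torch.float32)
expected = rmsnorm_reference(x, gamma)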
torch-ext/rmsnorm_kernel/__init__.py ADDED
File without changes
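The __init__.py is committed empty. In kernel-builder projects the build step typically generates a `_ops` module next to it, and the package usually re-exports a thin wrapper; a sketch of what that could look like (the generated `_ops` module is an assumption, and none of this is part of the commit):

import torch

from ._ops import ops  # generated by kernel-builder at build time (assumption)

def rmsnorm_forward(out: torch.Tensor, x: torch.Tensor, gamma: torch.Tensor) -> None:
    # Thin pass-through to the registered CUDA op; writes the result into `out`.
    ops.rmsnorm_forward(out, x, gamma)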
torch-ext/torch_bindings.cpp ADDED
@@ -0,0 +1,12 @@
+#include <torch/library.h>
+
+#include "registration.h"
+#include "torch_bindings.h"
+
+TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+  // Schema: `out` is mutated in place (Tensor!); `input` and `gamma` are read-only.
+  ops.def("rmsnorm_forward(Tensor! out, Tensor input, Tensor gamma) -> ()");
+  ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
+}
+
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
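Once the extension library is loaded, the registered op is also reachable through `torch.ops` under whatever namespace TORCH_EXTENSION_NAME expands to for this build. A sketch assuming the namespace is plain `rmsnorm_kernel` (in practice it may carry a build-specific suffix):

import torch

x = torch.randn(2, 16, 64, device="cuda", dtype=torch.float32)
gamma = torch.ones(64, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)

# Namespace is an assumption; it follows TORCH_EXTENSION_NAME at build time.
torch.ops.rmsnorm_kernel.rmsnorm_forward(out, x, gamma)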
torch-ext/torch_bindings.h ADDED
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <torch/torch.h>
+
+void rmsnorm_forward(torch::Tensor &out, torch::Tensor const &input, torch::Tensor const &gamma);