medmekk (HF Staff) committed on
Commit 4303459 · verified · 1 Parent(s): 5c7601b

Upload custom kernels
build.toml ADDED
@@ -0,0 +1,15 @@
+ [general]
+ name = "rmsnorm_kernel"
+
+ [torch]
+ src = [
+     "torch-ext/torch_binding.cpp",
+     "torch-ext/torch_binding.h"
+ ]
+
+ [kernel.rmsnorm_kernel]
+ src = [
+     "rmsnorm_kernel/rmsnorm.cu",
+ ]
+ depends = ["torch"]
+ cuda-capabilities = ["12.3"]
flake.nix ADDED
@@ -0,0 +1,13 @@
+ {
+   description = "Flake for Torch kernel extension";
+
+   inputs = {
+     kernel-builder.url = "github:huggingface/kernel-builder";
+   };
+
+   outputs = { self, kernel-builder, }:
+     kernel-builder.lib.genFlakeOutputs {
+       path = ./.;
+       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+     };
+ }
rmsnorm_kernel/rmsnorm.cu ADDED
@@ -0,0 +1,169 @@
+ #include <torch/extension.h>
+ #include <thrust/execution_policy.h>
+ #include <thrust/for_each.h>
+ #include <thrust/iterator/counting_iterator.h>
+ #include <cmath>
+ #include <thrust/device_vector.h>
+ #include <thrust/copy.h>
+ #include <iostream>
+ #include <iomanip> // For formatting output
+
+ constexpr float EPS = 1e-5f; // constexpr so the constant is usable in device code
+
+ // CPU reference implementation of RMSNorm: out = x / sqrt(mean(x^2) + EPS) * gamma
+ torch::Tensor rmsnorm_forward_cpu(torch::Tensor x, torch::Tensor gamma) {
+     int B = x.size(0), S = x.size(1), H = x.size(2);
+     auto out = torch::empty_like(x);
+
+     auto x_accessor = x.accessor<float, 3>();
+     auto gamma_accessor = gamma.accessor<float, 1>();
+     auto out_accessor = out.accessor<float, 3>();
+
+     // Process each (batch, sequence) row
+     for (int b = 0; b < B; ++b) {
+         for (int s = 0; s < S; ++s) {
+             // Calculate the root mean square of the row
+             float sum_sq = 0.0f;
+             for (int h = 0; h < H; ++h) {
+                 float val = x_accessor[b][s][h];
+                 sum_sq += val * val;
+             }
+             float rms = std::sqrt(sum_sq / H + EPS);
+
+             // Normalize and scale
+             for (int h = 0; h < H; ++h) {
+                 out_accessor[b][s][h] = (x_accessor[b][s][h] / rms) * gamma_accessor[h];
+             }
+         }
+     }
+
+     return out;
+ }
+
+ // One invocation normalizes one row: a single GPU thread loops over hidden_dim.
+ struct RmsnormFunctor {
+     const float* x;
+     const float* gamma;
+     float* out;
+     int hidden_dim;
+
+     RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
+         : x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}
+
+     __device__
+     void operator()(int row_idx) {
+         const float* row_x = x + row_idx * hidden_dim;
+         float* row_out = out + row_idx * hidden_dim;
+
+         float sum_sq = 0.0f;
+         for (int i = 0; i < hidden_dim; ++i)
+             sum_sq += row_x[i] * row_x[i];
+
+         float rms = sqrtf(sum_sq / hidden_dim + EPS);
+
+         for (int i = 0; i < hidden_dim; ++i)
+             row_out[i] = (row_x[i] / rms) * gamma[i];
+     }
+ };
+
+ torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor gamma) {
+     TORCH_CHECK(x.is_cuda() && gamma.is_cuda(), "inputs must be CUDA tensors");
+     TORCH_CHECK(x.is_contiguous() && gamma.is_contiguous(), "inputs must be contiguous");
+     TORCH_CHECK(x.dim() == 3, "x must have shape (batch, sequence, hidden)");
+
+     int B = x.size(0), S = x.size(1), H = x.size(2);
+     int rows = B * S;
+
+     // Create output tensor with same shape as input
+     auto out = torch::empty_like(x);
+
+     const float* x_ptr = x.data_ptr<float>();
+     const float* gamma_ptr = gamma.data_ptr<float>();
+     float* out_ptr = out.data_ptr<float>();
+
+     // Launch one functor invocation per (batch, sequence) row
+     thrust::counting_iterator<int> iter(0);
+     thrust::for_each(
+         thrust::device,
+         iter, iter + rows,
+         RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H)
+     );
+
+     return out;
+ }
+
+ // int main() {
+ //     int B = 2, S = 2, H = 4;
+
+ //     // Create tensors directly on CPU first
+ //     auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
+
+ //     // Initialize with CPU data
+ //     torch::Tensor x_cpu = torch::tensor({
+ //         {
+ //             {1.0f, 2.0f, 3.0f, 4.0f},
+ //             {5.0f, 6.0f, 7.0f, 8.0f}
+ //         },
+ //         {
+ //             {2.0f, 2.0f, 2.0f, 2.0f},
+ //             {9.0f, 10.0f, 11.0f, 12.0f}
+ //         }
+ //     }, options_cpu);
+
+ //     torch::Tensor gamma_cpu = torch::tensor({1.0f, 1.0f, 1.0f, 1.0f}, options_cpu);
+
+ //     // Run CPU version
+ //     std::cout << "===== CPU IMPLEMENTATION RESULTS =====" << std::endl;
+ //     torch::Tensor out_cpu = rmsnorm_forward_cpu(x_cpu, gamma_cpu);
+ //     auto cpu_accessor = out_cpu.accessor<float, 3>();
+
+ //     for (int b = 0; b < B; ++b) {
+ //         for (int s = 0; s < S; ++s) {
+ //             std::cout << "Row " << (b * S + s) << ": ";
+ //             for (int h = 0; h < H; ++h) {
+ //                 std::cout << std::fixed << std::setprecision(6) << cpu_accessor[b][s][h] << " ";
+ //             }
+ //             std::cout << "\n";
+ //         }
+ //     }
+
+ //     // Move tensors to CUDA for GPU version
+ //     auto cuda_options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+ //     torch::Tensor x_cuda = x_cpu.to(torch::kCUDA);
+ //     torch::Tensor gamma_cuda = gamma_cpu.to(torch::kCUDA);
+
+ //     // Call the CUDA kernel wrapper
+ //     std::cout << "\n===== GPU IMPLEMENTATION RESULTS =====" << std::endl;
+ //     torch::Tensor out_cuda = rmsnorm_forward(x_cuda, gamma_cuda);
+
+ //     // Copy result back to CPU and print
+ //     auto gpu_result_on_cpu = out_cuda.cpu();
+ //     auto gpu_accessor = gpu_result_on_cpu.accessor<float, 3>();
+
+ //     for (int b = 0; b < B; ++b) {
+ //         for (int s = 0; s < S; ++s) {
+ //             std::cout << "Row " << (b * S + s) << ": ";
+ //             for (int h = 0; h < H; ++h) {
+ //                 std::cout << std::fixed << std::setprecision(6) << gpu_accessor[b][s][h] << " ";
+ //             }
+ //             std::cout << "\n";
+ //         }
+ //     }
+
+ //     // Check if results match
+ //     std::cout << "\n===== COMPARISON =====" << std::endl;
+ //     float max_diff = 0.0f;
+ //     for (int b = 0; b < B; ++b) {
+ //         for (int s = 0; s < S; ++s) {
+ //             for (int h = 0; h < H; ++h) {
+ //                 float diff = std::abs(cpu_accessor[b][s][h] - gpu_accessor[b][s][h]);
+ //                 max_diff = std::max(max_diff, diff);
+ //             }
+ //         }
+ //     }
+ //     std::cout << "Maximum difference between CPU and GPU results: "
+ //               << std::scientific << max_diff << std::endl;
+ //     std::cout << (max_diff < 1e-5 ? "PASSED: Results match!" : "FAILED: Results don't match!") << std::endl;
+
+ //     return 0;
+ // }
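
Once the extension is built, the kernel is easy to sanity-check from Python against a pure-PyTorch reference. A minimal sketch, assuming the built package is importable as rmsnorm_kernel with the usual kernel-builder _ops layout (the import path and the shapes are assumptions):

import torch

from rmsnorm_kernel._ops import ops  # hypothetical import path (kernel-builder layout)

EPS = 1e-5  # must match the EPS constant compiled into rmsnorm.cu


def rmsnorm_reference(x: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
    # Same math as the kernel: x / sqrt(mean(x^2) + eps) * gamma over the last dim.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + EPS)
    return x / rms * gamma


x = torch.randn(2, 2, 4, device="cuda", dtype=torch.float32)
gamma = torch.ones(4, device="cuda", dtype=torch.float32)

torch.testing.assert_close(ops.rmsnorm_forward(x, gamma), rmsnorm_reference(x, gamma))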
torch-ext/rmsnorm_kernel/__init__.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+ import torch.nn as nn
+
+ from ._ops import ops
+
+
+ class LlamaRMSNorm(nn.Module):
+     # Declared as annotations only (no __init__): instances are expected to
+     # receive weight and variance_epsilon from the module this layer
+     # replaces.
+     weight: torch.Tensor
+     variance_epsilon: float
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         # The registered op takes only (input, gamma); the epsilon is the
+         # EPS constant compiled into rmsnorm.cu, so variance_epsilon is not
+         # forwarded here.
+         return ops.rmsnorm_forward(hidden_states, self.weight)
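
For reference, this is how the layer might be exercised directly. A minimal sketch, assuming the built package is importable as rmsnorm_kernel; since the class deliberately defines no __init__, the caller (normally the module being replaced) supplies weight and variance_epsilon:

import torch
from rmsnorm_kernel import LlamaRMSNorm  # hypothetical import path

hidden = 64  # hypothetical hidden size
norm = LlamaRMSNorm()
norm.weight = torch.nn.Parameter(torch.ones(hidden, device="cuda"))
norm.variance_epsilon = 1e-5  # informational only: the kernel uses its compiled-in EPS

x = torch.randn(1, 8, hidden, device="cuda", dtype=torch.float32)
y = norm(x)  # dispatches to ops.rmsnorm_forward(x, norm.weight)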
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,13 @@
+ #include <torch/library.h>
+
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+     // rmsnorm_forward allocates and returns the output tensor, so the
+     // schema declares a Tensor return value.
+     ops.def("rmsnorm_forward(Tensor input, Tensor gamma) -> Tensor");
+     ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
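
Because the op is registered via TORCH_LIBRARY_EXPAND, it is also reachable through the torch.ops namespace once the extension is loaded. A minimal sketch; the namespace equals TORCH_EXTENSION_NAME, which kernel-builder sets at build time, so rmsnorm_kernel below is an assumption:

import torch
import rmsnorm_kernel  # hypothetical package name; importing loads and registers the extension

x = torch.randn(2, 3, 8, device="cuda")
gamma = torch.ones(8, device="cuda")
out = torch.ops.rmsnorm_kernel.rmsnorm_forward(x, gamma)  # namespace name is an assumption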
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,6 @@
+ #pragma once
+
+ #include <torch/torch.h>
+
+ // Declaration must match the definition in rmsnorm_kernel/rmsnorm.cu.
+ torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor gamma);