Upload custom kernels
- build.toml +15 -0
- flake.nix +13 -0
- rmsnorm_kernel/rmsnorm.cu +163 -0
- torch-ext/rmsnorm_kernel/__init__.py +21 -0
- torch-ext/torch_binding.cpp +11 -0
- torch-ext/torch_binding.h +5 -0
build.toml
ADDED
@@ -0,0 +1,15 @@
[general]
name = "rmsnorm_kernel"

[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h"
]

[kernel.rmsnorm_kernel]
src = [
    "rmsnorm_kernel/rmsnorm.cu",
]
depends = ["torch"]
cuda-capabilities = ["12.3"]
flake.nix
ADDED
@@ -0,0 +1,13 @@
{
  description = "Flake for Torch kernel extension";

  inputs = {
    kernel-builder.url = "github:huggingface/kernel-builder";
  };

  outputs = { self, kernel-builder, }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
rmsnorm_kernel/rmsnorm.cu
ADDED
@@ -0,0 +1,163 @@
#include <torch/extension.h>
#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <cmath>
#include <thrust/device_vector.h>
#include <thrust/copy.h>
#include <iostream>
#include <iomanip> // For formatting output

const float EPS = 1e-5f;

// CPU implementation of RMSNorm
torch::Tensor rmsnorm_forward_cpu(torch::Tensor x, torch::Tensor gamma) {
    int B = x.size(0), S = x.size(1), H = x.size(2);
    auto out = torch::empty_like(x);

    auto x_accessor = x.accessor<float, 3>();
    auto gamma_accessor = gamma.accessor<float, 1>();
    auto out_accessor = out.accessor<float, 3>();

    // Process each (batch, sequence) row independently
    for (int b = 0; b < B; ++b) {
        for (int s = 0; s < S; ++s) {
            // Calculate root mean square over the hidden dimension
            float sum_sq = 0.0f;
            for (int h = 0; h < H; ++h) {
                float val = x_accessor[b][s][h];
                sum_sq += val * val;
            }
            float rms = std::sqrt(sum_sq / H + EPS);

            // Normalize and scale
            for (int h = 0; h < H; ++h) {
                out_accessor[b][s][h] = (x_accessor[b][s][h] / rms) * gamma_accessor[h];
            }
        }
    }

    return out;
}

// Device functor: each invocation normalizes one row of `hidden_dim` elements.
struct RmsnormFunctor {
    const float* x;
    const float* gamma;
    float* out;
    int hidden_dim;

    RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
        : x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}

    __device__
    void operator()(int row_idx) {
        const float* row_x = x + row_idx * hidden_dim;
        float* row_out = out + row_idx * hidden_dim;

        float sum_sq = 0.0f;
        for (int i = 0; i < hidden_dim; ++i)
            sum_sq += row_x[i] * row_x[i];

        float rms = sqrtf(sum_sq / hidden_dim + EPS);

        for (int i = 0; i < hidden_dim; ++i)
            row_out[i] = (row_x[i] / rms) * gamma[i];
    }
};

// CUDA entry point; the signature matches the declaration in
// torch-ext/torch_binding.h so the registered op links against it.
torch::Tensor rmsnorm_forward(torch::Tensor const &x, torch::Tensor const &gamma) {
    int B = x.size(0), S = x.size(1), H = x.size(2);
    int rows = B * S;

    // Create output tensor with same shape as input
    auto out = torch::empty_like(x);

    const float* x_ptr = x.data_ptr<float>();
    const float* gamma_ptr = gamma.data_ptr<float>();
    float* out_ptr = out.data_ptr<float>();

    // One Thrust iteration per (batch, seq) row, executed on the current CUDA device
    thrust::counting_iterator<int> iter(0);
    thrust::for_each(
        thrust::device,
        iter, iter + rows,
        RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H)
    );

    return out;
}

// Standalone CPU/GPU comparison harness, kept for reference:
// int main() {
//     int B = 2, S = 2, H = 4;
//
//     // Create tensors directly on CPU first
//     auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
//
//     // Initialize with CPU data
//     torch::Tensor x_cpu = torch::tensor({
//         {
//             {1.0f, 2.0f, 3.0f, 4.0f},
//             {5.0f, 6.0f, 7.0f, 8.0f}
//         },
//         {
//             {2.0f, 2.0f, 2.0f, 2.0f},
//             {9.0f, 10.0f, 11.0f, 12.0f}
//         }
//     }, options_cpu);
//
//     torch::Tensor gamma_cpu = torch::tensor({1.0f, 1.0f, 1.0f, 1.0f}, options_cpu);
//
//     // Run CPU version
//     std::cout << "===== CPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cpu = rmsnorm_forward_cpu(x_cpu, gamma_cpu);
//     auto cpu_accessor = out_cpu.accessor<float, 3>();
//
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << cpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//
//     // Move tensors to CUDA for GPU version
//     auto cuda_options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
//     torch::Tensor x_cuda = x_cpu.to(torch::kCUDA);
//     torch::Tensor gamma_cuda = gamma_cpu.to(torch::kCUDA);
//
//     // Call the CUDA kernel wrapper
//     std::cout << "\n===== GPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cuda = rmsnorm_forward(x_cuda, gamma_cuda);
//
//     // Copy result back to CPU and print
//     auto gpu_result_on_cpu = out_cuda.cpu();
//     auto gpu_accessor = gpu_result_on_cpu.accessor<float, 3>();
//
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << gpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//
//     // Check if results match
//     std::cout << "\n===== COMPARISON =====" << std::endl;
//     float max_diff = 0.0f;
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             for (int h = 0; h < H; ++h) {
//                 float diff = std::abs(cpu_accessor[b][s][h] - gpu_accessor[b][s][h]);
//                 max_diff = std::max(max_diff, diff);
//             }
//         }
//     }
//     std::cout << "Maximum difference between CPU and GPU results: "
//               << std::scientific << max_diff << std::endl;
//     std::cout << (max_diff < 1e-5 ? "PASSED: Results match!" : "FAILED: Results don't match!") << std::endl;
//
//     return 0;
// }
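For reference, both implementations above compute the same per-row formula; a minimal PyTorch sketch of that math (the function name is illustrative, not part of the extension):

import torch

def rmsnorm_reference(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # x: (B, S, H), gamma: (H,) -- same formula as rmsnorm.cu:
    # out = x / sqrt(mean(x * x, dim=-1) + eps) * gamma
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * gamma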
torch-ext/rmsnorm_kernel/__init__.py
ADDED
@@ -0,0 +1,21 @@
import torch
import torch.nn as nn

from ._ops import ops


class LlamaRMSNorm(nn.Module):
    weight: torch.Tensor
    variance_epsilon: float

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The registered op takes only (input, gamma); epsilon is the fixed
        # EPS constant inside the kernel, so variance_epsilon is not forwarded.
        return ops.rmsnorm_forward(hidden_states, self.weight)
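A minimal usage sketch, assuming the built kernel has been pushed to the Hub and is loaded with the `kernels` package; the repo id below is hypothetical:

import torch
from kernels import get_kernel

# Hypothetical repo id -- substitute the repo this kernel is actually published under.
rmsnorm_kernel = get_kernel("your-username/rmsnorm_kernel")

x = torch.randn(2, 2, 4, device="cuda", dtype=torch.float32)
gamma = torch.ones(4, device="cuda", dtype=torch.float32)

# `ops` is re-exported by torch-ext/rmsnorm_kernel/__init__.py above.
out = rmsnorm_kernel.ops.rmsnorm_forward(x, gamma)
print(out.shape)  # torch.Size([2, 2, 4])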
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,11 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schema returns a Tensor, matching the C++ implementation's signature.
  ops.def("rmsnorm_forward(Tensor input, Tensor gamma) -> Tensor");
  ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,5 @@
#pragma once

#include <torch/torch.h>

torch::Tensor rmsnorm_forward(torch::Tensor const &input, torch::Tensor const &gamma);