Commit
·
f3b99fb
1
Parent(s):
d14fd4d
feat(rms-norm): Impl fused RMSNorm
Browse files- README.md +1 -0
- activation/block_reduce.h +20 -0
- activation/{activation_kernels.cu → poly_norm.cu} +9 -20
- activation/rms_norm.cu +168 -0
- build.toml +3 -1
- build/flake.lock +0 -168
- build/flake.nix +0 -11
- build/torch26-cxx11-rocm62-x86_64-linux/activation/__init__.py +9 -0
- build/{torch27-cxx11-rocm63-x86_64-linux/activation/_activation_704692b_dirty.abi3.so → torch26-cxx11-rocm62-x86_64-linux/activation/_activation_d14fd4d_dirty.abi3.so} +2 -2
- build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py +3 -3
- build/torch26-cxx11-rocm62-x86_64-linux/activation/layers.py +17 -3
- build/torch26-cxx11-rocm62-x86_64-linux/activation/rms_norm.py +34 -0
- build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py +9 -0
- build/{torch26-cxx11-rocm62-x86_64-linux/activation/_activation_704692b_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_d14fd4d_dirty.abi3.so} +2 -2
- build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py +17 -3
- build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +34 -0
- tests/conftest.py +1 -1
- tests/kernels/{test_activation.py → test_poly_norm.py} +0 -0
- tests/kernels/{test_perf.py → test_poly_norm_perf.py} +1 -1
- tests/kernels/test_rms_norm.py +72 -0
- torch-ext/activation/__init__.py +9 -0
- torch-ext/activation/layers.py +17 -3
- torch-ext/activation/rms_norm.py +34 -0
- torch-ext/torch_binding.cpp +8 -0
- torch-ext/torch_binding.h +3 -0
README.md
CHANGED
@@ -9,6 +9,7 @@ Activation is a python package that contains custom CUDA-based activation kernel
|
|
9 |
|
10 |
- Currently implemented
|
11 |
- [PolyNorm](https://arxiv.org/html/2411.03884v1)
|
|
|
12 |
|
13 |
## Usage
|
14 |
|
|
|
9 |
|
10 |
- Currently implemented
|
11 |
- [PolyNorm](https://arxiv.org/html/2411.03884v1)
|
12 |
+
- [RMSNorm](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html)
|
13 |
|
14 |
## Usage
|
15 |
|
activation/block_reduce.h
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
namespace motif {
|
2 |
+
|
3 |
+
template <typename acc_t, int BLOCK_SIZE>
|
4 |
+
__device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d) {
|
5 |
+
// TODO: Optimize with warp-level primitives
|
6 |
+
__syncthreads();
|
7 |
+
|
8 |
+
shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f;
|
9 |
+
__syncthreads();
|
10 |
+
for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
|
11 |
+
if (threadIdx.x < stride) {
|
12 |
+
shared[threadIdx.x] += shared[threadIdx.x + stride];
|
13 |
+
}
|
14 |
+
__syncthreads();
|
15 |
+
}
|
16 |
+
|
17 |
+
return shared[0];
|
18 |
+
}
|
19 |
+
|
20 |
+
} // motif
|
activation/{activation_kernels.cu → poly_norm.cu}
RENAMED
@@ -9,26 +9,10 @@
|
|
9 |
#include "dispatch_utils.h"
|
10 |
#include "assert_utils.h"
|
11 |
#include "atomic_utils.h"
|
|
|
12 |
|
13 |
namespace motif {
|
14 |
|
15 |
-
template <typename acc_t, int BLOCK_SIZE>
|
16 |
-
__device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d) {
|
17 |
-
// TODO: Optimize with warp-level primitives
|
18 |
-
__syncthreads();
|
19 |
-
|
20 |
-
shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f;
|
21 |
-
__syncthreads();
|
22 |
-
for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
|
23 |
-
if (threadIdx.x < stride) {
|
24 |
-
shared[threadIdx.x] += shared[threadIdx.x + stride];
|
25 |
-
}
|
26 |
-
__syncthreads();
|
27 |
-
}
|
28 |
-
|
29 |
-
return shared[0];
|
30 |
-
}
|
31 |
-
|
32 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
33 |
__global__ void poly_norm_kernel(
|
34 |
scalar_t* __restrict__ out, // [..., d]
|
@@ -251,7 +235,12 @@ void poly_norm_backward(
|
|
251 |
}
|
252 |
);
|
253 |
|
254 |
-
|
255 |
-
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
257 |
}
|
|
|
9 |
#include "dispatch_utils.h"
|
10 |
#include "assert_utils.h"
|
11 |
#include "atomic_utils.h"
|
12 |
+
#include "block_reduce.h"
|
13 |
|
14 |
namespace motif {
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
17 |
__global__ void poly_norm_kernel(
|
18 |
scalar_t* __restrict__ out, // [..., d]
|
|
|
235 |
}
|
236 |
);
|
237 |
|
238 |
+
if (bias_grad.defined()) {
|
239 |
+
at::sum_out(bias_grad, output_grad);
|
240 |
+
bias_grad.resize_({1});
|
241 |
+
}
|
242 |
+
|
243 |
+
if (weight_grad.defined()) {
|
244 |
+
at::sum_out(weight_grad, temp_weight_grad, {0});
|
245 |
+
}
|
246 |
}
|
activation/rms_norm.cu
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <ATen/cuda/CUDAContext.h>
|
2 |
+
#include <ATen/Functions.h>
|
3 |
+
#include <torch/all.h>
|
4 |
+
#include <c10/cuda/CUDAGuard.h>
|
5 |
+
|
6 |
+
#include <cmath>
|
7 |
+
|
8 |
+
#include "cuda_compat.h"
|
9 |
+
#include "dispatch_utils.h"
|
10 |
+
#include "assert_utils.h"
|
11 |
+
#include "atomic_utils.h"
|
12 |
+
#include "block_reduce.h"
|
13 |
+
|
14 |
+
namespace motif {
|
15 |
+
|
16 |
+
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
17 |
+
__global__ void rms_norm_kernel(
|
18 |
+
scalar_t* __restrict__ out, // [..., d]
|
19 |
+
const scalar_t* __restrict__ input, // [..., d]
|
20 |
+
const scalar_t* __restrict__ weight, // [d]
|
21 |
+
const float eps,
|
22 |
+
const int d
|
23 |
+
) {
|
24 |
+
|
25 |
+
const int64_t token_idx = blockIdx.x;
|
26 |
+
const int64_t vec_idx = threadIdx.x;
|
27 |
+
acc_t sum_square = 0.0f;
|
28 |
+
|
29 |
+
for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
|
30 |
+
acc_t x = input[token_idx * d + idx];
|
31 |
+
sum_square += x * x;
|
32 |
+
}
|
33 |
+
|
34 |
+
__shared__ acc_t shared[BLOCK_SIZE];
|
35 |
+
|
36 |
+
acc_t variance =
|
37 |
+
_block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
|
38 |
+
acc_t scale = rsqrt(variance + eps);
|
39 |
+
for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
|
40 |
+
acc_t x = input[token_idx * d + idx];
|
41 |
+
acc_t w = weight[idx];
|
42 |
+
out[token_idx * d + idx] = w * x * scale;
|
43 |
+
}
|
44 |
+
}
|
45 |
+
|
46 |
+
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
47 |
+
__global__ void rms_norm_backward_kernel(
|
48 |
+
scalar_t* __restrict__ input_grad, // [..., d]
|
49 |
+
acc_t* __restrict__ temp_weight_grad, // [..., d]
|
50 |
+
const scalar_t* __restrict__ output_grad, // [..., d]
|
51 |
+
const scalar_t* __restrict__ input, // [..., d]
|
52 |
+
const scalar_t* __restrict__ weight, // [d]
|
53 |
+
const float eps,
|
54 |
+
const int d
|
55 |
+
) {
|
56 |
+
const int64_t token_idx = blockIdx.x;
|
57 |
+
const int64_t vec_idx = threadIdx.x;
|
58 |
+
acc_t d_sum = 0.0f;
|
59 |
+
acc_t sum_square = 0.0f;
|
60 |
+
|
61 |
+
for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
|
62 |
+
acc_t x = input[token_idx * d + idx];
|
63 |
+
acc_t dy = output_grad[token_idx * d + idx];
|
64 |
+
acc_t w = weight[idx];
|
65 |
+
d_sum += dy * x * w;
|
66 |
+
sum_square += x * x;
|
67 |
+
}
|
68 |
+
|
69 |
+
__shared__ acc_t shared[BLOCK_SIZE];
|
70 |
+
|
71 |
+
d_sum = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, d_sum, d);
|
72 |
+
acc_t variance =
|
73 |
+
_block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
|
74 |
+
acc_t scale = rsqrt(variance + eps);
|
75 |
+
acc_t scale_cubed = scale * scale * scale;
|
76 |
+
acc_t dxx = d_sum * scale_cubed / d;
|
77 |
+
|
78 |
+
for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
|
79 |
+
acc_t x = input[token_idx * d + idx];
|
80 |
+
acc_t dy = output_grad[token_idx * d + idx];
|
81 |
+
acc_t w = weight[idx];
|
82 |
+
|
83 |
+
input_grad[token_idx * d + idx] =
|
84 |
+
scale * dy * w - dxx * x;
|
85 |
+
|
86 |
+
if (temp_weight_grad) {
|
87 |
+
temp_weight_grad[token_idx * d + idx] = dy * x * scale;
|
88 |
+
}
|
89 |
+
}
|
90 |
+
}
|
91 |
+
|
92 |
+
} // namespace motif
|
93 |
+
|
94 |
+
|
95 |
+
void rms_norm(torch::Tensor& out, // [..., d]
|
96 |
+
const torch::Tensor& input, // [..., d]
|
97 |
+
const torch::Tensor& weight, // [d]
|
98 |
+
double eps)
|
99 |
+
{
|
100 |
+
AssertTensorShapeEqual(input, out, "input", "out");
|
101 |
+
AssertTensorNotNull(weight, "weight");
|
102 |
+
// TODO shape check
|
103 |
+
|
104 |
+
constexpr int BLOCK_SIZE = 256;
|
105 |
+
|
106 |
+
int d = input.size(-1);
|
107 |
+
int64_t num_tokens = input.numel() / input.size(-1);
|
108 |
+
dim3 grid(num_tokens);
|
109 |
+
dim3 block(BLOCK_SIZE);
|
110 |
+
|
111 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
112 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
113 |
+
MOTIF_DISPATCH_FLOATING_TYPES(
|
114 |
+
input.scalar_type(), "rms_norm_kernel", [&] {
|
115 |
+
motif::rms_norm_kernel<scalar_t, float, BLOCK_SIZE>
|
116 |
+
<<<grid, block, 0, stream>>>(
|
117 |
+
out.data_ptr<scalar_t>(),
|
118 |
+
input.data_ptr<scalar_t>(),
|
119 |
+
weight.data_ptr<scalar_t>(),
|
120 |
+
eps, d);
|
121 |
+
}
|
122 |
+
);
|
123 |
+
}
|
124 |
+
|
125 |
+
void rms_norm_backward(
|
126 |
+
torch::Tensor& input_grad, // [..., d]
|
127 |
+
torch::Tensor& weight_grad, // [..., d]
|
128 |
+
const torch::Tensor& output_grad, // [d]
|
129 |
+
const torch::Tensor& input, // [d]
|
130 |
+
const torch::Tensor& weight, // [d]
|
131 |
+
double eps) {
|
132 |
+
AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
|
133 |
+
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
134 |
+
AssertTensorNotNull(weight, "weight");
|
135 |
+
// TODO shape check
|
136 |
+
// weight_grad, input_grad can be nullable
|
137 |
+
|
138 |
+
constexpr int BLOCK_SIZE = 256;
|
139 |
+
|
140 |
+
int d = input.size(-1);
|
141 |
+
int64_t num_tokens = input.numel() / input.size(-1);
|
142 |
+
dim3 grid(num_tokens);
|
143 |
+
dim3 block(BLOCK_SIZE);
|
144 |
+
|
145 |
+
torch::Tensor temp_weight_grad =
|
146 |
+
torch::empty({num_tokens, d},
|
147 |
+
input.options().dtype(torch::kFloat));
|
148 |
+
|
149 |
+
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
150 |
+
|
151 |
+
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
152 |
+
MOTIF_DISPATCH_FLOATING_TYPES(
|
153 |
+
input.scalar_type(), "rms_norm_backward_kernel", [&] {
|
154 |
+
motif::rms_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
|
155 |
+
<<<grid, block, 0, stream>>>(
|
156 |
+
input_grad.data_ptr<scalar_t>(),
|
157 |
+
temp_weight_grad.data_ptr<float>(),
|
158 |
+
output_grad.data_ptr<scalar_t>(),
|
159 |
+
input.data_ptr<scalar_t>(),
|
160 |
+
weight.data_ptr<scalar_t>(),
|
161 |
+
eps, d);
|
162 |
+
}
|
163 |
+
);
|
164 |
+
|
165 |
+
if (weight_grad.defined()) {
|
166 |
+
at::sum_out(weight_grad, temp_weight_grad, {0});
|
167 |
+
}
|
168 |
+
}
|
build.toml
CHANGED
@@ -12,8 +12,10 @@ src = [
|
|
12 |
backend = "rocm"
|
13 |
rocm-archs = [ "gfx90a" ]
|
14 |
src = [
|
15 |
-
"activation/
|
|
|
16 |
"activation/cuda_compat.h",
|
|
|
17 |
"activation/dispatch_utils.h",
|
18 |
"activation/assert_utils.h",
|
19 |
"activation/atomic_utils.h",
|
|
|
12 |
backend = "rocm"
|
13 |
rocm-archs = [ "gfx90a" ]
|
14 |
src = [
|
15 |
+
"activation/poly_norm.cu",
|
16 |
+
"activation/rms_norm.cu",
|
17 |
"activation/cuda_compat.h",
|
18 |
+
"activation/block_reduce.h",
|
19 |
"activation/dispatch_utils.h",
|
20 |
"activation/assert_utils.h",
|
21 |
"activation/atomic_utils.h",
|
build/flake.lock
DELETED
@@ -1,168 +0,0 @@
|
|
1 |
-
{
|
2 |
-
"nodes": {
|
3 |
-
"flake-compat": {
|
4 |
-
"locked": {
|
5 |
-
"lastModified": 1747046372,
|
6 |
-
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
7 |
-
"owner": "edolstra",
|
8 |
-
"repo": "flake-compat",
|
9 |
-
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
10 |
-
"type": "github"
|
11 |
-
},
|
12 |
-
"original": {
|
13 |
-
"owner": "edolstra",
|
14 |
-
"repo": "flake-compat",
|
15 |
-
"type": "github"
|
16 |
-
}
|
17 |
-
},
|
18 |
-
"flake-compat_2": {
|
19 |
-
"locked": {
|
20 |
-
"lastModified": 1733328505,
|
21 |
-
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
22 |
-
"owner": "edolstra",
|
23 |
-
"repo": "flake-compat",
|
24 |
-
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
25 |
-
"type": "github"
|
26 |
-
},
|
27 |
-
"original": {
|
28 |
-
"owner": "edolstra",
|
29 |
-
"repo": "flake-compat",
|
30 |
-
"type": "github"
|
31 |
-
}
|
32 |
-
},
|
33 |
-
"flake-utils": {
|
34 |
-
"inputs": {
|
35 |
-
"systems": "systems"
|
36 |
-
},
|
37 |
-
"locked": {
|
38 |
-
"lastModified": 1731533236,
|
39 |
-
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
40 |
-
"owner": "numtide",
|
41 |
-
"repo": "flake-utils",
|
42 |
-
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
43 |
-
"type": "github"
|
44 |
-
},
|
45 |
-
"original": {
|
46 |
-
"owner": "numtide",
|
47 |
-
"repo": "flake-utils",
|
48 |
-
"type": "github"
|
49 |
-
}
|
50 |
-
},
|
51 |
-
"flake-utils_2": {
|
52 |
-
"inputs": {
|
53 |
-
"systems": "systems_2"
|
54 |
-
},
|
55 |
-
"locked": {
|
56 |
-
"lastModified": 1731533236,
|
57 |
-
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
58 |
-
"owner": "numtide",
|
59 |
-
"repo": "flake-utils",
|
60 |
-
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
61 |
-
"type": "github"
|
62 |
-
},
|
63 |
-
"original": {
|
64 |
-
"owner": "numtide",
|
65 |
-
"repo": "flake-utils",
|
66 |
-
"type": "github"
|
67 |
-
}
|
68 |
-
},
|
69 |
-
"hf-nix": {
|
70 |
-
"inputs": {
|
71 |
-
"flake-compat": "flake-compat_2",
|
72 |
-
"flake-utils": "flake-utils_2",
|
73 |
-
"nixpkgs": "nixpkgs"
|
74 |
-
},
|
75 |
-
"locked": {
|
76 |
-
"lastModified": 1747919133,
|
77 |
-
"narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
|
78 |
-
"owner": "huggingface",
|
79 |
-
"repo": "hf-nix",
|
80 |
-
"rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
|
81 |
-
"type": "github"
|
82 |
-
},
|
83 |
-
"original": {
|
84 |
-
"owner": "huggingface",
|
85 |
-
"repo": "hf-nix",
|
86 |
-
"type": "github"
|
87 |
-
}
|
88 |
-
},
|
89 |
-
"kernel-builder": {
|
90 |
-
"inputs": {
|
91 |
-
"flake-compat": "flake-compat",
|
92 |
-
"flake-utils": "flake-utils",
|
93 |
-
"hf-nix": "hf-nix",
|
94 |
-
"nixpkgs": [
|
95 |
-
"kernel-builder",
|
96 |
-
"hf-nix",
|
97 |
-
"nixpkgs"
|
98 |
-
]
|
99 |
-
},
|
100 |
-
"locked": {
|
101 |
-
"lastModified": 1748620233,
|
102 |
-
"narHash": "sha256-VULm9HgGXvo3pyfsPy3SOhoqgkuqbGSaSemvzNUbdIU=",
|
103 |
-
"owner": "huggingface",
|
104 |
-
"repo": "kernel-builder",
|
105 |
-
"rev": "da3340e5b3cbb6086600420f4814b033395788d1",
|
106 |
-
"type": "github"
|
107 |
-
},
|
108 |
-
"original": {
|
109 |
-
"owner": "huggingface",
|
110 |
-
"repo": "kernel-builder",
|
111 |
-
"type": "github"
|
112 |
-
}
|
113 |
-
},
|
114 |
-
"nixpkgs": {
|
115 |
-
"locked": {
|
116 |
-
"lastModified": 1747820358,
|
117 |
-
"narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
|
118 |
-
"owner": "danieldk",
|
119 |
-
"repo": "nixpkgs",
|
120 |
-
"rev": "d3c1681180717528068082103bf323147de6ab0b",
|
121 |
-
"type": "github"
|
122 |
-
},
|
123 |
-
"original": {
|
124 |
-
"owner": "danieldk",
|
125 |
-
"ref": "cudatoolkit-12.9-kernel-builder",
|
126 |
-
"repo": "nixpkgs",
|
127 |
-
"type": "github"
|
128 |
-
}
|
129 |
-
},
|
130 |
-
"root": {
|
131 |
-
"inputs": {
|
132 |
-
"kernel-builder": "kernel-builder"
|
133 |
-
}
|
134 |
-
},
|
135 |
-
"systems": {
|
136 |
-
"locked": {
|
137 |
-
"lastModified": 1681028828,
|
138 |
-
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
139 |
-
"owner": "nix-systems",
|
140 |
-
"repo": "default",
|
141 |
-
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
142 |
-
"type": "github"
|
143 |
-
},
|
144 |
-
"original": {
|
145 |
-
"owner": "nix-systems",
|
146 |
-
"repo": "default",
|
147 |
-
"type": "github"
|
148 |
-
}
|
149 |
-
},
|
150 |
-
"systems_2": {
|
151 |
-
"locked": {
|
152 |
-
"lastModified": 1681028828,
|
153 |
-
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
154 |
-
"owner": "nix-systems",
|
155 |
-
"repo": "default",
|
156 |
-
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
157 |
-
"type": "github"
|
158 |
-
},
|
159 |
-
"original": {
|
160 |
-
"owner": "nix-systems",
|
161 |
-
"repo": "default",
|
162 |
-
"type": "github"
|
163 |
-
}
|
164 |
-
}
|
165 |
-
},
|
166 |
-
"root": "root",
|
167 |
-
"version": 7
|
168 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/flake.nix
DELETED
@@ -1,11 +0,0 @@
|
|
1 |
-
{
|
2 |
-
description = "Flake for Torch kernel extension";
|
3 |
-
inputs = {
|
4 |
-
kernel-builder.url = "github:huggingface/kernel-builder";
|
5 |
-
};
|
6 |
-
outputs = { self, kernel-builder, }:
|
7 |
-
kernel-builder.lib.genFlakeOutputs {
|
8 |
-
path = ./.;
|
9 |
-
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
|
10 |
-
};
|
11 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
build/torch26-cxx11-rocm62-x86_64-linux/activation/__init__.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
from . import layers
|
4 |
from ._ops import ops
|
5 |
from .poly_norm import PolyNormFunction
|
|
|
6 |
|
7 |
|
8 |
def poly_norm(
|
@@ -14,6 +15,14 @@ def poly_norm(
|
|
14 |
return PolyNormFunction.apply(x, weight, bias, eps)
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
__all__ = [
|
18 |
"poly_norm",
|
19 |
"layers",
|
|
|
3 |
from . import layers
|
4 |
from ._ops import ops
|
5 |
from .poly_norm import PolyNormFunction
|
6 |
+
from .rms_norm import RMSNormFunction
|
7 |
|
8 |
|
9 |
def poly_norm(
|
|
|
15 |
return PolyNormFunction.apply(x, weight, bias, eps)
|
16 |
|
17 |
|
18 |
+
def rms_norm(
|
19 |
+
x: torch.Tensor,
|
20 |
+
weight: torch.Tensor,
|
21 |
+
eps: float = 1e-6,
|
22 |
+
) -> None:
|
23 |
+
return RMSNormFunction.apply(x, weight, eps)
|
24 |
+
|
25 |
+
|
26 |
__all__ = [
|
27 |
"poly_norm",
|
28 |
"layers",
|
build/{torch27-cxx11-rocm63-x86_64-linux/activation/_activation_704692b_dirty.abi3.so → torch26-cxx11-rocm62-x86_64-linux/activation/_activation_d14fd4d_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:179bfe6bd5484e81b1d8fa6cc3e2596837946a17f0761b0bb2521fd162669046
|
3 |
+
size 2656296
|
build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _activation_d14fd4d_dirty
|
3 |
+
ops = torch.ops._activation_d14fd4d_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_activation_d14fd4d_dirty::{op_name}"
|
build/torch26-cxx11-rocm62-x86_64-linux/activation/layers.py
CHANGED
@@ -2,13 +2,14 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from .poly_norm import PolyNormFunction
|
|
|
5 |
|
6 |
|
7 |
class PolyNorm(nn.Module):
|
8 |
-
def __init__(self, eps=1e-6):
|
9 |
super().__init__()
|
10 |
-
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
|
11 |
-
self.bias = torch.nn.Parameter(torch.zeros(1))
|
12 |
self.eps = eps
|
13 |
|
14 |
def forward(
|
@@ -16,3 +17,16 @@ class PolyNorm(nn.Module):
|
|
16 |
x: torch.Tensor,
|
17 |
):
|
18 |
return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from .poly_norm import PolyNormFunction
|
5 |
+
from .rms_norm import RMSNormFunction
|
6 |
|
7 |
|
8 |
class PolyNorm(nn.Module):
|
9 |
+
def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
|
10 |
super().__init__()
|
11 |
+
self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
|
12 |
+
self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
|
13 |
self.eps = eps
|
14 |
|
15 |
def forward(
|
|
|
17 |
x: torch.Tensor,
|
18 |
):
|
19 |
return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
|
20 |
+
|
21 |
+
|
22 |
+
class RMSNorm(nn.Module):
|
23 |
+
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
|
24 |
+
super().__init__()
|
25 |
+
self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
|
26 |
+
self.eps = eps
|
27 |
+
|
28 |
+
def forward(
|
29 |
+
self,
|
30 |
+
x: torch.Tensor,
|
31 |
+
):
|
32 |
+
return RMSNormFunction.apply(x, self.weight, self.eps)
|
build/torch26-cxx11-rocm62-x86_64-linux/activation/rms_norm.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ._ops import ops
|
4 |
+
|
5 |
+
|
6 |
+
# Inherit from Function
|
7 |
+
class RMSNormFunction(torch.autograd.Function):
|
8 |
+
# Note that forward, setup_context, and backward are @staticmethods
|
9 |
+
@staticmethod
|
10 |
+
def forward(input, weight, eps):
|
11 |
+
output = torch.empty_like(input)
|
12 |
+
ops.rms_norm(output, input, weight, eps)
|
13 |
+
return output
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
# inputs is a Tuple of all of the inputs passed to forward.
|
17 |
+
# output is the output of the forward().
|
18 |
+
def setup_context(ctx, inputs, output):
|
19 |
+
input, weight, eps = inputs
|
20 |
+
ctx.save_for_backward(input, weight)
|
21 |
+
ctx.eps = eps
|
22 |
+
|
23 |
+
# This function has only a single output, so it gets only one gradient
|
24 |
+
@staticmethod
|
25 |
+
def backward(ctx, output_grad):
|
26 |
+
input, weight = ctx.saved_tensors
|
27 |
+
eps = ctx.eps
|
28 |
+
|
29 |
+
input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
|
30 |
+
weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
|
31 |
+
|
32 |
+
ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
|
33 |
+
|
34 |
+
return input_grad, weight_grad, None
|
build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
from . import layers
|
4 |
from ._ops import ops
|
5 |
from .poly_norm import PolyNormFunction
|
|
|
6 |
|
7 |
|
8 |
def poly_norm(
|
@@ -14,6 +15,14 @@ def poly_norm(
|
|
14 |
return PolyNormFunction.apply(x, weight, bias, eps)
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
__all__ = [
|
18 |
"poly_norm",
|
19 |
"layers",
|
|
|
3 |
from . import layers
|
4 |
from ._ops import ops
|
5 |
from .poly_norm import PolyNormFunction
|
6 |
+
from .rms_norm import RMSNormFunction
|
7 |
|
8 |
|
9 |
def poly_norm(
|
|
|
15 |
return PolyNormFunction.apply(x, weight, bias, eps)
|
16 |
|
17 |
|
18 |
+
def rms_norm(
|
19 |
+
x: torch.Tensor,
|
20 |
+
weight: torch.Tensor,
|
21 |
+
eps: float = 1e-6,
|
22 |
+
) -> None:
|
23 |
+
return RMSNormFunction.apply(x, weight, eps)
|
24 |
+
|
25 |
+
|
26 |
__all__ = [
|
27 |
"poly_norm",
|
28 |
"layers",
|
build/{torch26-cxx11-rocm62-x86_64-linux/activation/_activation_704692b_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_d14fd4d_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:94debfd52e15f782eb9dd328d9311080d803276745e440b176b20a7031299e3f
|
3 |
+
size 2642736
|
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
import torch
|
2 |
-
from . import
|
3 |
-
ops = torch.ops.
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
-
return f"
|
|
|
1 |
import torch
|
2 |
+
from . import _activation_d14fd4d_dirty
|
3 |
+
ops = torch.ops._activation_d14fd4d_dirty
|
4 |
|
5 |
def add_op_namespace_prefix(op_name: str):
|
6 |
"""
|
7 |
Prefix op by namespace.
|
8 |
"""
|
9 |
+
return f"_activation_d14fd4d_dirty::{op_name}"
|
build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py
CHANGED
@@ -2,13 +2,14 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from .poly_norm import PolyNormFunction
|
|
|
5 |
|
6 |
|
7 |
class PolyNorm(nn.Module):
|
8 |
-
def __init__(self, eps=1e-6):
|
9 |
super().__init__()
|
10 |
-
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
|
11 |
-
self.bias = torch.nn.Parameter(torch.zeros(1))
|
12 |
self.eps = eps
|
13 |
|
14 |
def forward(
|
@@ -16,3 +17,16 @@ class PolyNorm(nn.Module):
|
|
16 |
x: torch.Tensor,
|
17 |
):
|
18 |
return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from .poly_norm import PolyNormFunction
|
5 |
+
from .rms_norm import RMSNormFunction
|
6 |
|
7 |
|
8 |
class PolyNorm(nn.Module):
|
9 |
+
def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
|
10 |
super().__init__()
|
11 |
+
self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
|
12 |
+
self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
|
13 |
self.eps = eps
|
14 |
|
15 |
def forward(
|
|
|
17 |
x: torch.Tensor,
|
18 |
):
|
19 |
return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
|
20 |
+
|
21 |
+
|
22 |
+
class RMSNorm(nn.Module):
|
23 |
+
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
|
24 |
+
super().__init__()
|
25 |
+
self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
|
26 |
+
self.eps = eps
|
27 |
+
|
28 |
+
def forward(
|
29 |
+
self,
|
30 |
+
x: torch.Tensor,
|
31 |
+
):
|
32 |
+
return RMSNormFunction.apply(x, self.weight, self.eps)
|
build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ._ops import ops
|
4 |
+
|
5 |
+
|
6 |
+
# Inherit from Function
|
7 |
+
class RMSNormFunction(torch.autograd.Function):
|
8 |
+
# Note that forward, setup_context, and backward are @staticmethods
|
9 |
+
@staticmethod
|
10 |
+
def forward(input, weight, eps):
|
11 |
+
output = torch.empty_like(input)
|
12 |
+
ops.rms_norm(output, input, weight, eps)
|
13 |
+
return output
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
# inputs is a Tuple of all of the inputs passed to forward.
|
17 |
+
# output is the output of the forward().
|
18 |
+
def setup_context(ctx, inputs, output):
|
19 |
+
input, weight, eps = inputs
|
20 |
+
ctx.save_for_backward(input, weight)
|
21 |
+
ctx.eps = eps
|
22 |
+
|
23 |
+
# This function has only a single output, so it gets only one gradient
|
24 |
+
@staticmethod
|
25 |
+
def backward(ctx, output_grad):
|
26 |
+
input, weight = ctx.saved_tensors
|
27 |
+
eps = ctx.eps
|
28 |
+
|
29 |
+
input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
|
30 |
+
weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
|
31 |
+
|
32 |
+
ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
|
33 |
+
|
34 |
+
return input_grad, weight_grad, None
|
tests/conftest.py
CHANGED
@@ -4,7 +4,7 @@ import numpy as np
|
|
4 |
import plotly.graph_objects as go
|
5 |
import pytest
|
6 |
|
7 |
-
from .kernels.
|
8 |
|
9 |
logger = logging.getLogger(__name__)
|
10 |
DO_PLOT = False
|
|
|
4 |
import plotly.graph_objects as go
|
5 |
import pytest
|
6 |
|
7 |
+
from .kernels.test_poly_norm_perf import PERF_RESULTS, PerfResult
|
8 |
|
9 |
logger = logging.getLogger(__name__)
|
10 |
DO_PLOT = False
|
tests/kernels/{test_activation.py → test_poly_norm.py}
RENAMED
File without changes
|
tests/kernels/{test_perf.py → test_poly_norm_perf.py}
RENAMED
@@ -6,7 +6,7 @@ import torch
|
|
6 |
|
7 |
import activation
|
8 |
|
9 |
-
from .
|
10 |
from .utils import assert_close
|
11 |
|
12 |
CASES = [
|
|
|
6 |
|
7 |
import activation
|
8 |
|
9 |
+
from .test_poly_norm import poly_norm
|
10 |
from .utils import assert_close
|
11 |
|
12 |
CASES = [
|
tests/kernels/test_rms_norm.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import pytest
|
4 |
+
import torch
|
5 |
+
|
6 |
+
import activation
|
7 |
+
|
8 |
+
from .utils import assert_close, opcheck
|
9 |
+
|
10 |
+
DTYPES = [torch.float, torch.bfloat16, torch.half]
|
11 |
+
# NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
12 |
+
# D = [512, 13824] # Arbitrary values for testing
|
13 |
+
NUM_TOKENS = [7, 13] # Arbitrary values for testing
|
14 |
+
D = [513] # Arbitrary values for testing
|
15 |
+
SEEDS = [0]
|
16 |
+
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
17 |
+
|
18 |
+
|
19 |
+
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
20 |
+
@pytest.mark.parametrize("d", D)
|
21 |
+
@pytest.mark.parametrize("dtype", DTYPES)
|
22 |
+
@pytest.mark.parametrize("seed", SEEDS)
|
23 |
+
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
24 |
+
def test_rms_norm(
|
25 |
+
num_tokens: int,
|
26 |
+
d: int,
|
27 |
+
dtype: torch.dtype,
|
28 |
+
seed: int,
|
29 |
+
device: str,
|
30 |
+
) -> None:
|
31 |
+
random.seed(seed)
|
32 |
+
torch.manual_seed(seed)
|
33 |
+
torch.set_default_device(device)
|
34 |
+
|
35 |
+
x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
|
36 |
+
weight = torch.randn(d, dtype=dtype, requires_grad=True)
|
37 |
+
eps = 1e-05
|
38 |
+
|
39 |
+
x.retain_grad()
|
40 |
+
weight.retain_grad()
|
41 |
+
# To separate gradient computation, clone the inputs
|
42 |
+
|
43 |
+
x_ref = x.detach().clone().requires_grad_(True)
|
44 |
+
weight_ref = weight.detach().clone().requires_grad_(True)
|
45 |
+
|
46 |
+
torch_layer = torch.nn.RMSNorm(d, eps=eps, dtype=dtype)
|
47 |
+
torch_layer.weight = torch.nn.Parameter(weight_ref)
|
48 |
+
|
49 |
+
op = activation.ops.rms_norm
|
50 |
+
fn = activation.rms_norm
|
51 |
+
layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
|
52 |
+
layer.weight = torch.nn.Parameter(weight)
|
53 |
+
|
54 |
+
out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
|
55 |
+
opcheck(op, (out, x, weight, eps))
|
56 |
+
|
57 |
+
out = fn(x, weight, eps)
|
58 |
+
mod_out = layer(x)
|
59 |
+
ref_out = torch_layer(x_ref)
|
60 |
+
|
61 |
+
assert_close(out, ref_out)
|
62 |
+
assert_close(mod_out, out, atol=0.0, rtol=0.0)
|
63 |
+
|
64 |
+
# test backward pass
|
65 |
+
out_grad = torch.randn_like(out)
|
66 |
+
out_grad = out_grad / out_grad.norm()
|
67 |
+
|
68 |
+
ref_out.backward(out_grad)
|
69 |
+
mod_out.backward(out_grad)
|
70 |
+
|
71 |
+
assert_close(x.grad, x_ref.grad)
|
72 |
+
assert_close(layer.weight.grad, torch_layer.weight.grad, rtol=0.05)
|
torch-ext/activation/__init__.py
CHANGED
@@ -3,6 +3,7 @@ import torch
|
|
3 |
from . import layers
|
4 |
from ._ops import ops
|
5 |
from .poly_norm import PolyNormFunction
|
|
|
6 |
|
7 |
|
8 |
def poly_norm(
|
@@ -14,6 +15,14 @@ def poly_norm(
|
|
14 |
return PolyNormFunction.apply(x, weight, bias, eps)
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
__all__ = [
|
18 |
"poly_norm",
|
19 |
"layers",
|
|
|
3 |
from . import layers
|
4 |
from ._ops import ops
|
5 |
from .poly_norm import PolyNormFunction
|
6 |
+
from .rms_norm import RMSNormFunction
|
7 |
|
8 |
|
9 |
def poly_norm(
|
|
|
15 |
return PolyNormFunction.apply(x, weight, bias, eps)
|
16 |
|
17 |
|
18 |
+
def rms_norm(
|
19 |
+
x: torch.Tensor,
|
20 |
+
weight: torch.Tensor,
|
21 |
+
eps: float = 1e-6,
|
22 |
+
) -> None:
|
23 |
+
return RMSNormFunction.apply(x, weight, eps)
|
24 |
+
|
25 |
+
|
26 |
__all__ = [
|
27 |
"poly_norm",
|
28 |
"layers",
|
torch-ext/activation/layers.py
CHANGED
@@ -2,13 +2,14 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from .poly_norm import PolyNormFunction
|
|
|
5 |
|
6 |
|
7 |
class PolyNorm(nn.Module):
|
8 |
-
def __init__(self, eps=1e-6):
|
9 |
super().__init__()
|
10 |
-
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
|
11 |
-
self.bias = torch.nn.Parameter(torch.zeros(1))
|
12 |
self.eps = eps
|
13 |
|
14 |
def forward(
|
@@ -16,3 +17,16 @@ class PolyNorm(nn.Module):
|
|
16 |
x: torch.Tensor,
|
17 |
):
|
18 |
return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch.nn as nn
|
3 |
|
4 |
from .poly_norm import PolyNormFunction
|
5 |
+
from .rms_norm import RMSNormFunction
|
6 |
|
7 |
|
8 |
class PolyNorm(nn.Module):
|
9 |
+
def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
|
10 |
super().__init__()
|
11 |
+
self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
|
12 |
+
self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
|
13 |
self.eps = eps
|
14 |
|
15 |
def forward(
|
|
|
17 |
x: torch.Tensor,
|
18 |
):
|
19 |
return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
|
20 |
+
|
21 |
+
|
22 |
+
class RMSNorm(nn.Module):
|
23 |
+
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
|
24 |
+
super().__init__()
|
25 |
+
self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
|
26 |
+
self.eps = eps
|
27 |
+
|
28 |
+
def forward(
|
29 |
+
self,
|
30 |
+
x: torch.Tensor,
|
31 |
+
):
|
32 |
+
return RMSNormFunction.apply(x, self.weight, self.eps)
|
torch-ext/activation/rms_norm.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ._ops import ops
|
4 |
+
|
5 |
+
|
6 |
+
# Inherit from Function
|
7 |
+
class RMSNormFunction(torch.autograd.Function):
|
8 |
+
# Note that forward, setup_context, and backward are @staticmethods
|
9 |
+
@staticmethod
|
10 |
+
def forward(input, weight, eps):
|
11 |
+
output = torch.empty_like(input)
|
12 |
+
ops.rms_norm(output, input, weight, eps)
|
13 |
+
return output
|
14 |
+
|
15 |
+
@staticmethod
|
16 |
+
# inputs is a Tuple of all of the inputs passed to forward.
|
17 |
+
# output is the output of the forward().
|
18 |
+
def setup_context(ctx, inputs, output):
|
19 |
+
input, weight, eps = inputs
|
20 |
+
ctx.save_for_backward(input, weight)
|
21 |
+
ctx.eps = eps
|
22 |
+
|
23 |
+
# This function has only a single output, so it gets only one gradient
|
24 |
+
@staticmethod
|
25 |
+
def backward(ctx, output_grad):
|
26 |
+
input, weight = ctx.saved_tensors
|
27 |
+
eps = ctx.eps
|
28 |
+
|
29 |
+
input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
|
30 |
+
weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
|
31 |
+
|
32 |
+
ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
|
33 |
+
|
34 |
+
return input_grad, weight_grad, None
|
torch-ext/torch_binding.cpp
CHANGED
@@ -9,6 +9,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
|
9 |
ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! bias_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
|
10 |
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
|
11 |
ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
}
|
13 |
|
14 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
9 |
ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! bias_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
|
10 |
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
|
11 |
ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
|
12 |
+
|
13 |
+
// Activation ops
|
14 |
+
ops.def("rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
|
15 |
+
ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
|
16 |
+
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
17 |
+
ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
|
18 |
+
|
19 |
+
|
20 |
}
|
21 |
|
22 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_binding.h
CHANGED
@@ -4,3 +4,6 @@
|
|
4 |
|
5 |
void poly_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, const torch::Tensor &bias, double eps);
|
6 |
void poly_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, torch::Tensor& bias_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps);
|
|
|
|
|
|
|
|
4 |
|
5 |
void poly_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, const torch::Tensor &bias, double eps);
|
6 |
void poly_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, torch::Tensor& bias_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps);
|
7 |
+
|
8 |
+
void rms_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, double eps);
|
9 |
+
void rms_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps);
|