iamwyldecat committed
Commit f3b99fb · 1 Parent(s): d14fd4d

feat(rms-norm): Impl fused RMSNorm

Files changed (26)
  1. README.md +1 -0
  2. activation/block_reduce.h +20 -0
  3. activation/{activation_kernels.cu → poly_norm.cu} +9 -20
  4. activation/rms_norm.cu +168 -0
  5. build.toml +3 -1
  6. build/flake.lock +0 -168
  7. build/flake.nix +0 -11
  8. build/torch26-cxx11-rocm62-x86_64-linux/activation/__init__.py +9 -0
  9. build/{torch27-cxx11-rocm63-x86_64-linux/activation/_activation_704692b_dirty.abi3.so → torch26-cxx11-rocm62-x86_64-linux/activation/_activation_d14fd4d_dirty.abi3.so} +2 -2
  10. build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py +3 -3
  11. build/torch26-cxx11-rocm62-x86_64-linux/activation/layers.py +17 -3
  12. build/torch26-cxx11-rocm62-x86_64-linux/activation/rms_norm.py +34 -0
  13. build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py +9 -0
  14. build/{torch26-cxx11-rocm62-x86_64-linux/activation/_activation_704692b_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_d14fd4d_dirty.abi3.so} +2 -2
  15. build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +3 -3
  16. build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py +17 -3
  17. build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +34 -0
  18. tests/conftest.py +1 -1
  19. tests/kernels/{test_activation.py → test_poly_norm.py} +0 -0
  20. tests/kernels/{test_perf.py → test_poly_norm_perf.py} +1 -1
  21. tests/kernels/test_rms_norm.py +72 -0
  22. torch-ext/activation/__init__.py +9 -0
  23. torch-ext/activation/layers.py +17 -3
  24. torch-ext/activation/rms_norm.py +34 -0
  25. torch-ext/torch_binding.cpp +8 -0
  26. torch-ext/torch_binding.h +3 -0
README.md CHANGED
@@ -9,6 +9,7 @@ Activation is a python package that contains custom CUDA-based activation kernel
 
  - Currently implemented
  - [PolyNorm](https://arxiv.org/html/2411.03884v1)
+ - [RMSNorm](https://docs.pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html)
 
  ## Usage
 
activation/block_reduce.h ADDED
@@ -0,0 +1,20 @@
+ namespace motif {
+
+ template <typename acc_t, int BLOCK_SIZE>
+ __device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d) {
+   // TODO: Optimize with warp-level primitives
+   __syncthreads();
+
+   shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f;
+   __syncthreads();
+   for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
+     if (threadIdx.x < stride) {
+       shared[threadIdx.x] += shared[threadIdx.x + stride];
+     }
+     __syncthreads();
+   }
+
+   return shared[0];
+ }
+
+ } // motif
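A quick note on the helper that moved into this header: `_block_reduce_sum` is a shared-memory tree reduction over one thread block. Each thread contributes one partial value, lanes with `threadIdx.x >= d` contribute zero, and every thread receives the block-wide sum from `shared[0]`. A minimal PyTorch sketch of the same semantics (the function name below is illustrative, not part of the package):

```python
import torch

def block_reduce_sum_reference(partials: torch.Tensor, d: int) -> torch.Tensor:
    """Emulates _block_reduce_sum: sum one partial value per thread,
    counting only the first d lanes, and return the block-wide total."""
    lane = torch.arange(partials.numel(), device=partials.device)
    return torch.where(lane < d, partials, torch.zeros_like(partials)).sum()
```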
activation/{activation_kernels.cu → poly_norm.cu} RENAMED
@@ -9,26 +9,10 @@
  #include "dispatch_utils.h"
  #include "assert_utils.h"
  #include "atomic_utils.h"
+ #include "block_reduce.h"
 
  namespace motif {
 
- template <typename acc_t, int BLOCK_SIZE>
- __device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d) {
-   // TODO: Optimize with warp-level primitives
-   __syncthreads();
-
-   shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f;
-   __syncthreads();
-   for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
-     if (threadIdx.x < stride) {
-       shared[threadIdx.x] += shared[threadIdx.x + stride];
-     }
-     __syncthreads();
-   }
-
-   return shared[0];
- }
-
  template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
  __global__ void poly_norm_kernel(
      scalar_t* __restrict__ out, // [..., d]
@@ -251,7 +235,12 @@ void poly_norm_backward(
      }
    );
 
-   at::sum_out(bias_grad, output_grad);
-   at::sum_out(weight_grad, temp_weight_grad, {0});
-   bias_grad.resize_({1});
+   if (bias_grad.defined()) {
+     at::sum_out(bias_grad, output_grad);
+     bias_grad.resize_({1});
+   }
+
+   if (weight_grad.defined()) {
+     at::sum_out(weight_grad, temp_weight_grad, {0});
+   }
  }
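The new `.defined()` guards make `bias_grad` and `weight_grad` optional: when the Python side passes `None` for a gradient it does not need, the C++ op receives an undefined tensor and the corresponding `at::sum_out` is skipped. A hedged sketch of a hypothetical caller (the helper name below is illustrative):

```python
import torch

def call_poly_norm_backward(ops, needs_grad, output_grad, x, weight, bias, eps):
    # Allocate gradient buffers only for the inputs that actually need them;
    # None reaches the C++ op as an undefined tensor and is skipped by .defined().
    input_grad = torch.empty_like(x) if needs_grad[0] else None
    weight_grad = torch.empty_like(weight) if needs_grad[1] else None
    bias_grad = torch.empty_like(bias) if needs_grad[2] else None
    ops.poly_norm_backward(input_grad, weight_grad, bias_grad,
                           output_grad, x, weight, eps)
    return input_grad, weight_grad, bias_grad
```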
activation/rms_norm.cu ADDED
@@ -0,0 +1,168 @@
+ #include <ATen/cuda/CUDAContext.h>
+ #include <ATen/Functions.h>
+ #include <torch/all.h>
+ #include <c10/cuda/CUDAGuard.h>
+
+ #include <cmath>
+
+ #include "cuda_compat.h"
+ #include "dispatch_utils.h"
+ #include "assert_utils.h"
+ #include "atomic_utils.h"
+ #include "block_reduce.h"
+
+ namespace motif {
+
+ template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
+ __global__ void rms_norm_kernel(
+     scalar_t* __restrict__ out,          // [..., d]
+     const scalar_t* __restrict__ input,  // [..., d]
+     const scalar_t* __restrict__ weight, // [d]
+     const float eps,
+     const int d
+ ) {
+
+   const int64_t token_idx = blockIdx.x;
+   const int64_t vec_idx = threadIdx.x;
+   acc_t sum_square = 0.0f;
+
+   for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
+     acc_t x = input[token_idx * d + idx];
+     sum_square += x * x;
+   }
+
+   __shared__ acc_t shared[BLOCK_SIZE];
+
+   acc_t variance =
+       _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
+   acc_t scale = rsqrt(variance + eps);
+   for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
+     acc_t x = input[token_idx * d + idx];
+     acc_t w = weight[idx];
+     out[token_idx * d + idx] = w * x * scale;
+   }
+ }
+
+ template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
+ __global__ void rms_norm_backward_kernel(
+     scalar_t* __restrict__ input_grad,        // [..., d]
+     acc_t* __restrict__ temp_weight_grad,     // [..., d]
+     const scalar_t* __restrict__ output_grad, // [..., d]
+     const scalar_t* __restrict__ input,       // [..., d]
+     const scalar_t* __restrict__ weight,      // [d]
+     const float eps,
+     const int d
+ ) {
+   const int64_t token_idx = blockIdx.x;
+   const int64_t vec_idx = threadIdx.x;
+   acc_t d_sum = 0.0f;
+   acc_t sum_square = 0.0f;
+
+   for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
+     acc_t x = input[token_idx * d + idx];
+     acc_t dy = output_grad[token_idx * d + idx];
+     acc_t w = weight[idx];
+     d_sum += dy * x * w;
+     sum_square += x * x;
+   }
+
+   __shared__ acc_t shared[BLOCK_SIZE];
+
+   d_sum = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, d_sum, d);
+   acc_t variance =
+       _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
+   acc_t scale = rsqrt(variance + eps);
+   acc_t scale_cubed = scale * scale * scale;
+   acc_t dxx = d_sum * scale_cubed / d;
+
+   for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
+     acc_t x = input[token_idx * d + idx];
+     acc_t dy = output_grad[token_idx * d + idx];
+     acc_t w = weight[idx];
+
+     input_grad[token_idx * d + idx] =
+         scale * dy * w - dxx * x;
+
+     if (temp_weight_grad) {
+       temp_weight_grad[token_idx * d + idx] = dy * x * scale;
+     }
+   }
+ }
+
+ } // namespace motif
+
+
+ void rms_norm(torch::Tensor& out,          // [..., d]
+               const torch::Tensor& input,  // [..., d]
+               const torch::Tensor& weight, // [d]
+               double eps)
+ {
+   AssertTensorShapeEqual(input, out, "input", "out");
+   AssertTensorNotNull(weight, "weight");
+   // TODO shape check
+
+   constexpr int BLOCK_SIZE = 256;
+
+   int d = input.size(-1);
+   int64_t num_tokens = input.numel() / input.size(-1);
+   dim3 grid(num_tokens);
+   dim3 block(BLOCK_SIZE);
+
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   MOTIF_DISPATCH_FLOATING_TYPES(
+     input.scalar_type(), "rms_norm_kernel", [&] {
+       motif::rms_norm_kernel<scalar_t, float, BLOCK_SIZE>
+         <<<grid, block, 0, stream>>>(
+           out.data_ptr<scalar_t>(),
+           input.data_ptr<scalar_t>(),
+           weight.data_ptr<scalar_t>(),
+           eps, d);
+     }
+   );
+ }
+
+ void rms_norm_backward(
+     torch::Tensor& input_grad,        // [..., d]
+     torch::Tensor& weight_grad,       // [..., d]
+     const torch::Tensor& output_grad, // [d]
+     const torch::Tensor& input,       // [d]
+     const torch::Tensor& weight,      // [d]
+     double eps) {
+   AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
+   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
+   AssertTensorNotNull(weight, "weight");
+   // TODO shape check
+   // weight_grad, input_grad can be nullable
+
+   constexpr int BLOCK_SIZE = 256;
+
+   int d = input.size(-1);
+   int64_t num_tokens = input.numel() / input.size(-1);
+   dim3 grid(num_tokens);
+   dim3 block(BLOCK_SIZE);
+
+   torch::Tensor temp_weight_grad =
+       torch::empty({num_tokens, d},
+                    input.options().dtype(torch::kFloat));
+
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+   MOTIF_DISPATCH_FLOATING_TYPES(
+     input.scalar_type(), "rms_norm_backward_kernel", [&] {
+       motif::rms_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
+         <<<grid, block, 0, stream>>>(
+           input_grad.data_ptr<scalar_t>(),
+           temp_weight_grad.data_ptr<float>(),
+           output_grad.data_ptr<scalar_t>(),
+           input.data_ptr<scalar_t>(),
+           weight.data_ptr<scalar_t>(),
+           eps, d);
+     }
+   );
+
+   if (weight_grad.defined()) {
+     at::sum_out(weight_grad, temp_weight_grad, {0});
+   }
+ }
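For readers checking the kernel math, here is a hedged PyTorch reference that mirrors what the two kernels above compute: float32 accumulation (the kernels are instantiated with `acc_t = float`), a per-token scale `rsqrt(mean(x^2) + eps)`, and the same backward formula, with the weight gradient reduced over tokens exactly as `at::sum_out(weight_grad, temp_weight_grad, {0})` does. The function names are illustrative only, not part of the package:

```python
import torch

def rms_norm_ref(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
    # x: [..., d], w: [d]; accumulate in float32 like the kernel's acc_t.
    xf, wf = x.float(), w.float()
    scale = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (wf * xf * scale).to(x.dtype)

def rms_norm_backward_ref(dy: torch.Tensor, x: torch.Tensor, w: torch.Tensor, eps: float):
    # Mirrors rms_norm_backward_kernel: dx = scale*dy*w - (sum(dy*x*w)*scale^3/d)*x,
    # dw = sum over tokens of dy*x*scale (the reduction over temp_weight_grad).
    xf, dyf, wf = x.float(), dy.float(), w.float()
    d = x.shape[-1]
    scale = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    d_sum = (dyf * xf * wf).sum(dim=-1, keepdim=True)
    dx = scale * dyf * wf - (d_sum * scale.pow(3) / d) * xf
    dw = (dyf * xf * scale).reshape(-1, d).sum(dim=0)
    return dx.to(x.dtype), dw
```

Fusing these steps into one kernel launch per pass avoids materializing the normalized intermediate and the extra global-memory round trips a composed eager-mode implementation would incur.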
build.toml CHANGED
@@ -12,8 +12,10 @@ src = [
  backend = "rocm"
  rocm-archs = [ "gfx90a" ]
  src = [
-   "activation/activation_kernels.cu",
+   "activation/poly_norm.cu",
+   "activation/rms_norm.cu",
    "activation/cuda_compat.h",
+   "activation/block_reduce.h",
    "activation/dispatch_utils.h",
    "activation/assert_utils.h",
    "activation/atomic_utils.h",
build/flake.lock DELETED
@@ -1,168 +0,0 @@
- {
-   "nodes": {
-     "flake-compat": {
-       "locked": {
-         "lastModified": 1747046372,
-         "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
-         "type": "github"
-       },
-       "original": {
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "type": "github"
-       }
-     },
-     "flake-compat_2": {
-       "locked": {
-         "lastModified": 1733328505,
-         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
-         "type": "github"
-       },
-       "original": {
-         "owner": "edolstra",
-         "repo": "flake-compat",
-         "type": "github"
-       }
-     },
-     "flake-utils": {
-       "inputs": {
-         "systems": "systems"
-       },
-       "locked": {
-         "lastModified": 1731533236,
-         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "type": "github"
-       }
-     },
-     "flake-utils_2": {
-       "inputs": {
-         "systems": "systems_2"
-       },
-       "locked": {
-         "lastModified": 1731533236,
-         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "numtide",
-         "repo": "flake-utils",
-         "type": "github"
-       }
-     },
-     "hf-nix": {
-       "inputs": {
-         "flake-compat": "flake-compat_2",
-         "flake-utils": "flake-utils_2",
-         "nixpkgs": "nixpkgs"
-       },
-       "locked": {
-         "lastModified": 1747919133,
-         "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
-         "owner": "huggingface",
-         "repo": "hf-nix",
-         "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
-         "type": "github"
-       },
-       "original": {
-         "owner": "huggingface",
-         "repo": "hf-nix",
-         "type": "github"
-       }
-     },
-     "kernel-builder": {
-       "inputs": {
-         "flake-compat": "flake-compat",
-         "flake-utils": "flake-utils",
-         "hf-nix": "hf-nix",
-         "nixpkgs": [
-           "kernel-builder",
-           "hf-nix",
-           "nixpkgs"
-         ]
-       },
-       "locked": {
-         "lastModified": 1748620233,
-         "narHash": "sha256-VULm9HgGXvo3pyfsPy3SOhoqgkuqbGSaSemvzNUbdIU=",
-         "owner": "huggingface",
-         "repo": "kernel-builder",
-         "rev": "da3340e5b3cbb6086600420f4814b033395788d1",
-         "type": "github"
-       },
-       "original": {
-         "owner": "huggingface",
-         "repo": "kernel-builder",
-         "type": "github"
-       }
-     },
-     "nixpkgs": {
-       "locked": {
-         "lastModified": 1747820358,
-         "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
-         "owner": "danieldk",
-         "repo": "nixpkgs",
-         "rev": "d3c1681180717528068082103bf323147de6ab0b",
-         "type": "github"
-       },
-       "original": {
-         "owner": "danieldk",
-         "ref": "cudatoolkit-12.9-kernel-builder",
-         "repo": "nixpkgs",
-         "type": "github"
-       }
-     },
-     "root": {
-       "inputs": {
-         "kernel-builder": "kernel-builder"
-       }
-     },
-     "systems": {
-       "locked": {
-         "lastModified": 1681028828,
-         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-         "owner": "nix-systems",
-         "repo": "default",
-         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nix-systems",
-         "repo": "default",
-         "type": "github"
-       }
-     },
-     "systems_2": {
-       "locked": {
-         "lastModified": 1681028828,
-         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
-         "owner": "nix-systems",
-         "repo": "default",
-         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
-         "type": "github"
-       },
-       "original": {
-         "owner": "nix-systems",
-         "repo": "default",
-         "type": "github"
-       }
-     }
-   },
-   "root": "root",
-   "version": 7
- }
build/flake.nix DELETED
@@ -1,11 +0,0 @@
- {
-   description = "Flake for Torch kernel extension";
-   inputs = {
-     kernel-builder.url = "github:huggingface/kernel-builder";
-   };
-   outputs = { self, kernel-builder, }:
-     kernel-builder.lib.genFlakeOutputs {
-       path = ./.;
-       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
-     };
- }
build/torch26-cxx11-rocm62-x86_64-linux/activation/__init__.py CHANGED
@@ -3,6 +3,7 @@ import torch
  from . import layers
  from ._ops import ops
  from .poly_norm import PolyNormFunction
+ from .rms_norm import RMSNormFunction
 
 
  def poly_norm(
@@ -14,6 +15,14 @@ def poly_norm(
      return PolyNormFunction.apply(x, weight, bias, eps)
 
 
+ def rms_norm(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     eps: float = 1e-6,
+ ) -> None:
+     return RMSNormFunction.apply(x, weight, eps)
+
+
  __all__ = [
      "poly_norm",
      "layers",
build/{torch27-cxx11-rocm63-x86_64-linux/activation/_activation_704692b_dirty.abi3.so → torch26-cxx11-rocm62-x86_64-linux/activation/_activation_d14fd4d_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:6fe6163d88e95c0d6847b3fe993cd80de677f89cfde7fc4d5c3ec2d0d96c9de8
- size 2395176
+ oid sha256:179bfe6bd5484e81b1d8fa6cc3e2596837946a17f0761b0bb2521fd162669046
+ size 2656296
build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _activation_704692b_dirty
- ops = torch.ops._activation_704692b_dirty
+ from . import _activation_d14fd4d_dirty
+ ops = torch.ops._activation_d14fd4d_dirty
 
  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_activation_704692b_dirty::{op_name}"
+     return f"_activation_d14fd4d_dirty::{op_name}"
build/torch26-cxx11-rocm62-x86_64-linux/activation/layers.py CHANGED
@@ -2,13 +2,14 @@ import torch
  import torch.nn as nn
 
  from .poly_norm import PolyNormFunction
+ from .rms_norm import RMSNormFunction
 
 
  class PolyNorm(nn.Module):
-     def __init__(self, eps=1e-6):
+     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
          super().__init__()
-         self.weight = torch.nn.Parameter(torch.ones(3) / 3)
-         self.bias = torch.nn.Parameter(torch.zeros(1))
+         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+         self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
          self.eps = eps
 
      def forward(
@@ -16,3 +17,16 @@ class PolyNorm(nn.Module):
          x: torch.Tensor,
      ):
          return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+         super().__init__()
+         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+         self.eps = eps
+
+     def forward(
+         self,
+         x: torch.Tensor,
+     ):
+         return RMSNormFunction.apply(x, self.weight, self.eps)
build/torch26-cxx11-rocm62-x86_64-linux/activation/rms_norm.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+
+ from ._ops import ops
+
+
+ # Inherit from Function
+ class RMSNormFunction(torch.autograd.Function):
+     # Note that forward, setup_context, and backward are @staticmethods
+     @staticmethod
+     def forward(input, weight, eps):
+         output = torch.empty_like(input)
+         ops.rms_norm(output, input, weight, eps)
+         return output
+
+     @staticmethod
+     # inputs is a Tuple of all of the inputs passed to forward.
+     # output is the output of the forward().
+     def setup_context(ctx, inputs, output):
+         input, weight, eps = inputs
+         ctx.save_for_backward(input, weight)
+         ctx.eps = eps
+
+     # This function has only a single output, so it gets only one gradient
+     @staticmethod
+     def backward(ctx, output_grad):
+         input, weight = ctx.saved_tensors
+         eps = ctx.eps
+
+         input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
+         weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+
+         ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+
+         return input_grad, weight_grad, None
build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py CHANGED
@@ -3,6 +3,7 @@ import torch
  from . import layers
  from ._ops import ops
  from .poly_norm import PolyNormFunction
+ from .rms_norm import RMSNormFunction
 
 
  def poly_norm(
@@ -14,6 +15,14 @@ def poly_norm(
      return PolyNormFunction.apply(x, weight, bias, eps)
 
 
+ def rms_norm(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     eps: float = 1e-6,
+ ) -> None:
+     return RMSNormFunction.apply(x, weight, eps)
+
+
  __all__ = [
      "poly_norm",
      "layers",
build/{torch26-cxx11-rocm62-x86_64-linux/activation/_activation_704692b_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_d14fd4d_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:417cf142fb8234b05f7e5b0be321d3a95ceafd7c0b3e5d3469579a52d78ddb1e
- size 2401160
+ oid sha256:94debfd52e15f782eb9dd328d9311080d803276745e440b176b20a7031299e3f
+ size 2642736
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
  import torch
- from . import _activation_704692b_dirty
- ops = torch.ops._activation_704692b_dirty
+ from . import _activation_d14fd4d_dirty
+ ops = torch.ops._activation_d14fd4d_dirty
 
  def add_op_namespace_prefix(op_name: str):
      """
      Prefix op by namespace.
      """
-     return f"_activation_704692b_dirty::{op_name}"
+     return f"_activation_d14fd4d_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py CHANGED
@@ -2,13 +2,14 @@ import torch
  import torch.nn as nn
 
  from .poly_norm import PolyNormFunction
+ from .rms_norm import RMSNormFunction
 
 
  class PolyNorm(nn.Module):
-     def __init__(self, eps=1e-6):
+     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
          super().__init__()
-         self.weight = torch.nn.Parameter(torch.ones(3) / 3)
-         self.bias = torch.nn.Parameter(torch.zeros(1))
+         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+         self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
          self.eps = eps
 
      def forward(
@@ -16,3 +17,16 @@ class PolyNorm(nn.Module):
          x: torch.Tensor,
      ):
          return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+         super().__init__()
+         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+         self.eps = eps
+
+     def forward(
+         self,
+         x: torch.Tensor,
+     ):
+         return RMSNormFunction.apply(x, self.weight, self.eps)
build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+
+ from ._ops import ops
+
+
+ # Inherit from Function
+ class RMSNormFunction(torch.autograd.Function):
+     # Note that forward, setup_context, and backward are @staticmethods
+     @staticmethod
+     def forward(input, weight, eps):
+         output = torch.empty_like(input)
+         ops.rms_norm(output, input, weight, eps)
+         return output
+
+     @staticmethod
+     # inputs is a Tuple of all of the inputs passed to forward.
+     # output is the output of the forward().
+     def setup_context(ctx, inputs, output):
+         input, weight, eps = inputs
+         ctx.save_for_backward(input, weight)
+         ctx.eps = eps
+
+     # This function has only a single output, so it gets only one gradient
+     @staticmethod
+     def backward(ctx, output_grad):
+         input, weight = ctx.saved_tensors
+         eps = ctx.eps
+
+         input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
+         weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+
+         ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+
+         return input_grad, weight_grad, None
tests/conftest.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
  import plotly.graph_objects as go
  import pytest
 
- from .kernels.test_perf import PERF_RESULTS, PerfResult
+ from .kernels.test_poly_norm_perf import PERF_RESULTS, PerfResult
 
  logger = logging.getLogger(__name__)
  DO_PLOT = False
tests/kernels/{test_activation.py → test_poly_norm.py} RENAMED
File without changes
tests/kernels/{test_perf.py → test_poly_norm_perf.py} RENAMED
@@ -6,7 +6,7 @@ import torch
 
  import activation
 
- from .test_activation import poly_norm
+ from .test_poly_norm import poly_norm
  from .utils import assert_close
 
  CASES = [
tests/kernels/test_rms_norm.py ADDED
@@ -0,0 +1,72 @@
+ import random
+
+ import pytest
+ import torch
+
+ import activation
+
+ from .utils import assert_close, opcheck
+
+ DTYPES = [torch.float, torch.bfloat16, torch.half]
+ # NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+ # D = [512, 13824]  # Arbitrary values for testing
+ NUM_TOKENS = [7, 13]  # Arbitrary values for testing
+ D = [513]  # Arbitrary values for testing
+ SEEDS = [0]
+ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+
+
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+ @pytest.mark.parametrize("d", D)
+ @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize("seed", SEEDS)
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
+ def test_rms_norm(
+     num_tokens: int,
+     d: int,
+     dtype: torch.dtype,
+     seed: int,
+     device: str,
+ ) -> None:
+     random.seed(seed)
+     torch.manual_seed(seed)
+     torch.set_default_device(device)
+
+     x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
+     weight = torch.randn(d, dtype=dtype, requires_grad=True)
+     eps = 1e-05
+
+     x.retain_grad()
+     weight.retain_grad()
+     # To separate gradient computation, clone the inputs
+
+     x_ref = x.detach().clone().requires_grad_(True)
+     weight_ref = weight.detach().clone().requires_grad_(True)
+
+     torch_layer = torch.nn.RMSNorm(d, eps=eps, dtype=dtype)
+     torch_layer.weight = torch.nn.Parameter(weight_ref)
+
+     op = activation.ops.rms_norm
+     fn = activation.rms_norm
+     layer = activation.layers.RMSNorm(d, eps=eps, dtype=dtype)
+     layer.weight = torch.nn.Parameter(weight)
+
+     out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+     opcheck(op, (out, x, weight, eps))
+
+     out = fn(x, weight, eps)
+     mod_out = layer(x)
+     ref_out = torch_layer(x_ref)
+
+     assert_close(out, ref_out)
+     assert_close(mod_out, out, atol=0.0, rtol=0.0)
+
+     # test backward pass
+     out_grad = torch.randn_like(out)
+     out_grad = out_grad / out_grad.norm()
+
+     ref_out.backward(out_grad)
+     mod_out.backward(out_grad)
+
+     assert_close(x.grad, x_ref.grad)
+     assert_close(layer.weight.grad, torch_layer.weight.grad, rtol=0.05)
torch-ext/activation/__init__.py CHANGED
@@ -3,6 +3,7 @@ import torch
  from . import layers
  from ._ops import ops
  from .poly_norm import PolyNormFunction
+ from .rms_norm import RMSNormFunction
 
 
  def poly_norm(
@@ -14,6 +15,14 @@ def poly_norm(
      return PolyNormFunction.apply(x, weight, bias, eps)
 
 
+ def rms_norm(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     eps: float = 1e-6,
+ ) -> None:
+     return RMSNormFunction.apply(x, weight, eps)
+
+
  __all__ = [
      "poly_norm",
      "layers",
torch-ext/activation/layers.py CHANGED
@@ -2,13 +2,14 @@ import torch
  import torch.nn as nn
 
  from .poly_norm import PolyNormFunction
+ from .rms_norm import RMSNormFunction
 
 
  class PolyNorm(nn.Module):
-     def __init__(self, eps=1e-6):
+     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
          super().__init__()
-         self.weight = torch.nn.Parameter(torch.ones(3) / 3)
-         self.bias = torch.nn.Parameter(torch.zeros(1))
+         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
+         self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
          self.eps = eps
 
      def forward(
@@ -16,3 +17,16 @@ class PolyNorm(nn.Module):
          x: torch.Tensor,
      ):
          return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
+         super().__init__()
+         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
+         self.eps = eps
+
+     def forward(
+         self,
+         x: torch.Tensor,
+     ):
+         return RMSNormFunction.apply(x, self.weight, self.eps)
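And a matching sketch for the module interface added above (sizes and dtype are again just examples):

```python
import torch
from activation.layers import RMSNorm

layer = RMSNorm(512, eps=1e-6, dtype=torch.bfloat16).to("cuda")
x = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)
y = layer(x)  # forwards to RMSNormFunction.apply(x, layer.weight, layer.eps)
```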
torch-ext/activation/rms_norm.py ADDED
@@ -0,0 +1,34 @@
+ import torch
+
+ from ._ops import ops
+
+
+ # Inherit from Function
+ class RMSNormFunction(torch.autograd.Function):
+     # Note that forward, setup_context, and backward are @staticmethods
+     @staticmethod
+     def forward(input, weight, eps):
+         output = torch.empty_like(input)
+         ops.rms_norm(output, input, weight, eps)
+         return output
+
+     @staticmethod
+     # inputs is a Tuple of all of the inputs passed to forward.
+     # output is the output of the forward().
+     def setup_context(ctx, inputs, output):
+         input, weight, eps = inputs
+         ctx.save_for_backward(input, weight)
+         ctx.eps = eps
+
+     # This function has only a single output, so it gets only one gradient
+     @staticmethod
+     def backward(ctx, output_grad):
+         input, weight = ctx.saved_tensors
+         eps = ctx.eps
+
+         input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
+         weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+
+         ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+
+         return input_grad, weight_grad, None
torch-ext/torch_binding.cpp CHANGED
@@ -9,6 +9,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
    ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! bias_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
    ops.impl("poly_norm", torch::kCUDA, &poly_norm);
    ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
+
+   // Activation ops
+   ops.def("rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
+   ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
+   ops.impl("rms_norm", torch::kCUDA, &rms_norm);
+   ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
+
+
  }
 
  REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
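As registered above, `rms_norm` follows an out-parameter convention: it writes into `out` and returns nothing. A small sketch of calling the raw op through the handle the package re-exports (tensor sizes here are illustrative):

```python
import torch
import activation

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
w = torch.ones(128, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)

# Schema: "rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()"
activation.ops.rms_norm(out, x, w, 1e-6)
```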
torch-ext/torch_binding.h CHANGED
@@ -4,3 +4,6 @@
 
  void poly_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, const torch::Tensor &bias, double eps);
  void poly_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, torch::Tensor& bias_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps);
+
+ void rms_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, double eps);
+ void rms_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps);