Kernels

Commit 44e9845 · 1 Parent(s): 7a7d761 · committed by iamwyldecat

feat(poly-norm): Add PolyNorm
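For reference, the operation added by this commit (as implemented in activation/activation_kernels.cu and mirrored by the Python reference in tests/kernels/test_activation.py) is, per token of width d with element-wise powers:

```latex
\mathrm{PolyNorm}(x) = w_2\,\frac{x}{\mathrm{rms}(x)}
                     + w_1\,\frac{x^2}{\mathrm{rms}(x^2)}
                     + w_0\,\frac{x^3}{\mathrm{rms}(x^3)} + b,
\qquad
\mathrm{rms}(v) = \sqrt{\tfrac{1}{d}\textstyle\sum_{i=1}^{d} v_i^2 + \varepsilon}
```

where w is the 3-element weight, b the scalar bias, and eps the stabilizing constant.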
.gitignore ADDED
@@ -0,0 +1,193 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/vim,python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=vim,python
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Distribution / packaging
+ .Python
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+ # ruff
+ .ruff_cache/
+
+ # LSP config files
+ pyrightconfig.json
+
+ ### Vim ###
+ # Swap
+ [._]*.s[a-v][a-z]
+ !*.svg # comment out if you don't need vector files
+ [._]*.sw[a-p]
+ [._]s[a-rt-v][a-z]
+ [._]ss[a-gi-z]
+ [._]sw[a-p]
+
+ # Session
+ Session.vim
+ Sessionx.vim
+
+ # Temporary
+ .netrwhist
+ *~
+ # Auto-generated tag files
+ tags
+ # Persistent undo
+ [._]*.un~
+
+ # End of https://www.toptal.com/developers/gitignore/api/vim,python
README.md ADDED
@@ -0,0 +1,4 @@
+ ---
+ tags:
+ - kernel
+ ---
activation/activation_kernels.cu ADDED
@@ -0,0 +1,267 @@
+ #include <ATen/cuda/CUDAContext.h>
+ #include <torch/all.h>
+ #include <c10/cuda/CUDAGuard.h>
+
+ #include <cmath>
+
+ #include "cuda_compat.h"
+ #include "dispatch_utils.h"
+ #include "assert_utils.h"
+ #include "atomic_utils.h"
+
+ namespace motif {
+
+ template <typename acc_t, int BLOCK_SIZE>
+ __device__ acc_t _block_reduce_sum(volatile acc_t* shared, const float val, const int d) {
+   // TODO: Optimize with warp-level primitives
+   shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f;
+   __syncthreads();
+   for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
+     if (threadIdx.x < stride) {
+       shared[threadIdx.x] += shared[threadIdx.x + stride];
+     }
+     __syncthreads();
+   }
+
+   return shared[0];
+ }
+
+ template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
+ __global__ void poly_norm_kernel(
+     scalar_t* __restrict__ out,          // [..., d]
+     const scalar_t* __restrict__ input,  // [..., d]
+     const scalar_t* __restrict__ weight, // [3]
+     const scalar_t* __restrict__ bias,   // [1]
+     const float eps,
+     const int d
+ ) {
+   const int64_t token_idx = blockIdx.x;
+
+   acc_t sum = 0.0f;
+   acc_t sum_square = 0.0f;
+   acc_t sum_cube = 0.0f;
+
+   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t x = input[token_idx * d + idx];
+     sum += pow(x, 2.0f);
+     sum_square += pow(x, 4.0f);
+     sum_cube += pow(x, 6.0f);
+   }
+
+   __shared__ acc_t shared[BLOCK_SIZE];
+
+   acc_t mean = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum, d) / d;
+   acc_t mean_square = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
+   acc_t mean_cube = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_cube, d) / d;
+
+   acc_t w0 = weight[0];
+   acc_t w1 = weight[1];
+   acc_t w2 = weight[2];
+   acc_t b = bias[0];
+
+   acc_t divisor = sqrt(mean + eps);
+   acc_t divisor_square = sqrt(mean_square + eps);
+   acc_t divisor_cube = sqrt(mean_cube + eps);
+
+   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t x = input[token_idx * d + idx];
+     acc_t x_square = pow(x, 2.0f);
+     acc_t x_cube = pow(x, 3.0f);
+     out[token_idx * d + idx] = w2 * x / divisor +
+                                w1 * x_square / divisor_square +
+                                w0 * x_cube / divisor_cube + b;
+   }
+ }
+
+ template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
+ __global__ void poly_norm_backward_kernel(
+     scalar_t* __restrict__ input_grad,        // [..., d]
+     scalar_t* __restrict__ weight_grad,       // [3]
+     scalar_t* __restrict__ bias_grad,         // [1]
+     const scalar_t* __restrict__ output_grad, // [..., d]
+     const scalar_t* __restrict__ input,       // [..., d]
+     const scalar_t* __restrict__ weight,      // [3]
+     const float eps,
+     const int d
+ ) {
+   const int64_t token_idx = blockIdx.x;
+
+   acc_t w0 = weight[0];
+   acc_t w1 = weight[1];
+   acc_t w2 = weight[2];
+
+   acc_t sum_2 = 0.0f;
+   acc_t sum_4 = 0.0f;
+   acc_t sum_6 = 0.0f;
+
+   acc_t sum_dx_1 = 0.0f;
+   acc_t sum_dx_2 = 0.0f;
+   acc_t sum_dx_3 = 0.0f;
+
+   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t dy = output_grad[token_idx * d + idx];
+
+     acc_t x_1 = input[token_idx * d + idx];
+     acc_t x_2 = x_1 * x_1;
+     acc_t x_3 = x_2 * x_1;
+     acc_t x_4 = x_3 * x_1;
+     acc_t x_6 = x_4 * x_2;
+
+     sum_2 += x_2;
+     sum_4 += x_4;
+     sum_6 += x_6;
+
+     sum_dx_1 += w2 * dy * x_1;
+     sum_dx_2 += w1 * dy * x_2;
+     sum_dx_3 += w0 * dy * x_3;
+   }
+
+   __shared__ acc_t shared[BLOCK_SIZE];
+
+   acc_t mean_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_2, d) / d + eps;
+   acc_t mean_4 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_4, d) / d + eps;
+   acc_t mean_6 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_6, d) / d + eps;
+
+   sum_dx_1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_1, d);
+   sum_dx_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_2, d);
+   sum_dx_3 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_3, d);
+
+   acc_t sq_mean_2 = rsqrtf(mean_2) * mean_2;
+   acc_t sq_mean_4 = rsqrtf(mean_4) * mean_4;
+   acc_t sq_mean_6 = rsqrtf(mean_6) * mean_6;
+
+   acc_t denom_2 = mean_2 * sq_mean_2;
+   acc_t denom_4 = mean_4 * sq_mean_4;
+   acc_t denom_6 = mean_6 * sq_mean_6;
+
+   acc_t sum_dw0 = 0;
+   acc_t sum_dw1 = 0;
+   acc_t sum_dw2 = 0;
+   acc_t sum_db = 0;
+
+   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t dy = output_grad[token_idx * d + idx];
+     acc_t x_1 = input[token_idx * d + idx];
+     acc_t x_2 = x_1 * x_1;
+     acc_t x_3 = x_2 * x_1;
+
+     acc_t _dx_1 = w2 * dy;
+     acc_t _dx_2 = w1 * dy;
+     acc_t _dx_3 = w0 * dy;
+
+     acc_t dx_3 =
+         3 * x_2 * (_dx_3 / sq_mean_6 - x_3 * sum_dx_3 / (d * denom_6));
+     acc_t dx_2 =
+         2 * x_1 * (_dx_2 / sq_mean_4 - x_2 * sum_dx_2 / (d * denom_4));
+     acc_t dx_1 =
+         _dx_1 / sq_mean_2 - x_1 * sum_dx_1 / (d * denom_2);
+
+     if (input_grad) {
+       input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3;
+     }
+
+     sum_dw0 += dy * (x_3 / sq_mean_6);
+     sum_dw1 += dy * (x_2 / sq_mean_4);
+     sum_dw2 += dy * (x_1 / sq_mean_2);
+     sum_db += dy;
+   }
+
+   if (weight_grad) {
+     sum_dw0 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw0, d);
+     sum_dw1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw1, d);
+     sum_dw2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw2, d);
+
+     if (threadIdx.x == 0) {
+       atomic_add(&weight_grad[0], sum_dw0);
+       atomic_add(&weight_grad[1], sum_dw1);
+       atomic_add(&weight_grad[2], sum_dw2);
+     }
+   }
+
+   if (bias_grad) {
+     sum_db = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_db, d);
+     if (threadIdx.x == 0) {
+       atomic_add(&bias_grad[0], sum_db);
+     }
+   }
+ }
+
+ } // namespace motif
+
+
+ void poly_norm(torch::Tensor& out,    // [..., d]
+                torch::Tensor& input,  // [..., d]
+                torch::Tensor& weight, // [3]
+                torch::Tensor& bias,   // [1]
+                double eps)
+ {
+   AssertTensorShapeEqual(input, out, "input", "out");
+   AssertTensorNotNull(weight, "weight");
+   AssertTensorNotNull(bias, "bias");
+   // TODO shape check
+
+   constexpr int BLOCK_SIZE = 256;
+
+   int d = input.size(-1);
+   int64_t num_tokens = input.numel() / input.size(-1);
+   dim3 grid(num_tokens);
+   dim3 block(BLOCK_SIZE);
+
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   MOTIF_DISPATCH_FLOATING_TYPES(
+     input.scalar_type(), "poly_norm_kernel", [&] {
+       motif::poly_norm_kernel<scalar_t, float, BLOCK_SIZE>
+         <<<grid, block, 0, stream>>>(
+           out.data_ptr<scalar_t>(),
+           input.data_ptr<scalar_t>(),
+           weight.data_ptr<scalar_t>(),
+           bias.data_ptr<scalar_t>(), eps, d);
+     }
+   );
+ }
+
+ void poly_norm_backward(
+     torch::Tensor& input_grad,  // [..., d]
+     torch::Tensor& weight_grad, // [3]
+     torch::Tensor& bias_grad,   // [1]
+     torch::Tensor& output_grad, // [..., d]
+     torch::Tensor& input,       // [..., d]
+     torch::Tensor& weight,      // [3]
+     double eps) {
+   AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
+   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
+   AssertTensorNotNull(weight, "weight");
+   // TODO shape check
+   // weight_grad, bias_grad and input_grad can be nullable
+
+   constexpr int BLOCK_SIZE = 256;
+
+   int d = input.size(-1);
+   int64_t num_tokens = input.numel() / input.size(-1);
+   dim3 grid(num_tokens);
+   dim3 block(BLOCK_SIZE);
+
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+   if (weight_grad.defined())
+     cudaMemsetAsync(weight_grad.data_ptr(), 0, weight_grad.numel() * weight_grad.element_size(), stream);
+   if (bias_grad.defined()) {
+     cudaMemsetAsync(bias_grad.data_ptr(), 0, bias_grad.numel() * bias_grad.element_size(), stream);
+   }
+
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+   MOTIF_DISPATCH_FLOATING_TYPES(
+     input.scalar_type(), "poly_norm_backward_kernel", [&] {
+       motif::poly_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
+         <<<grid, block, 0, stream>>>(
+           input_grad.data_ptr<scalar_t>(),
+           weight_grad.data_ptr<scalar_t>(),
+           bias_grad.data_ptr<scalar_t>(),
+           output_grad.data_ptr<scalar_t>(),
+           input.data_ptr<scalar_t>(),
+           weight.data_ptr<scalar_t>(),
+           eps, d);
+     }
+   );
+ }
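The kernel above is easiest to check against a plain PyTorch transcription. The sketch below is not part of the commit; it only mirrors the per-token math of poly_norm_kernel (the test suite ships an equivalent reference of its own):

```python
# Reference sketch (not part of the commit): the per-token forward pass that
# poly_norm_kernel computes, written in plain PyTorch for verification.
# `weight` has 3 elements and `bias` 1, as in the kernel.
import torch


def poly_norm_reference(x: torch.Tensor, weight: torch.Tensor,
                        bias: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Means of x^2, x^4, x^6 over the last dimension (the kernel's block reductions).
    mean = (x ** 2).mean(-1, keepdim=True)
    mean_square = (x ** 4).mean(-1, keepdim=True)
    mean_cube = (x ** 6).mean(-1, keepdim=True)

    # Divisors: sqrt(mean + eps), i.e. the RMS of x, x^2, x^3 respectively.
    divisor = torch.sqrt(mean + eps)
    divisor_square = torch.sqrt(mean_square + eps)
    divisor_cube = torch.sqrt(mean_cube + eps)

    return (weight[2] * x / divisor
            + weight[1] * x ** 2 / divisor_square
            + weight[0] * x ** 3 / divisor_cube
            + bias)
```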
activation/assert_utils.h ADDED
@@ -0,0 +1,22 @@
+ #pragma once
+
+ #include <ATen/cuda/CUDAContext.h>
+ #include <torch/all.h>
+
+ inline void AssertTensorNotNull(const torch::Tensor &tensor, const std::string &name) {
+   TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
+ }
+
+ inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a, const torch::Tensor &tensor_b,
+                                    const std::string &name_a, const std::string &name_b) {
+
+   AssertTensorNotNull(tensor_a, name_a);
+   AssertTensorNotNull(tensor_b, name_b);
+
+   auto tensor_shape_a = tensor_a.sizes();
+   auto tensor_shape_b = tensor_b.sizes();
+
+   TORCH_INTERNAL_ASSERT(tensor_shape_a.equals(tensor_shape_b),
+       name_a, " tensor shape should be equal to ", name_b,
+       " tensor shape. (actual: ", tensor_shape_a, ", expected: ", tensor_shape_b, ")");
+ }
activation/atomic_utils.h ADDED
@@ -0,0 +1,73 @@
+ #pragma once
+
+ #include <cuda.h>
+ #include <c10/util/BFloat16.h>
+ #include <c10/util/Half.h>
+
+ namespace motif {
+ template<typename scalar_t, typename acc_t>
+ __device__ inline void atomic_add(scalar_t* address, acc_t value) {
+   // TODO: change assert to a static_assert if possible
+   assert(false && "Unsupported type for atomic_add");
+ }
+
+ template<>
+ __device__ inline void atomic_add<float, float>(float* address, float value) {
+   atomicAdd(address, value);
+ }
+
+ template<>
+ __device__ inline void atomic_add<double, double>(double* address, double value) {
+   atomicAdd(address, value);
+ }
+
+ template<>
+ __device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16* _address, float value) {
+   volatile c10::BFloat16* address = const_cast<volatile c10::BFloat16*>(_address);
+
+   size_t offset = (size_t)address & 0x2;
+   volatile uint16_t* address_as_short =
+       reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
+   volatile uint32_t* address_as_uint =
+       reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+   bool is_32bit_aligned = offset == 0;
+
+   uint32_t current = address_as_uint[0];
+   uint32_t expected;
+
+   do {
+     expected = current;
+     c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
+     c10::BFloat16 next_bf16 = current_bf16 + value;
+     uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_bf16.x
+                                      : (current & 0x0000ffff) | (next_bf16.x << 16);
+     current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
+   } while (current != expected);
+ }
+
+ template<>
+ __device__ inline void atomic_add<c10::Half, float>(c10::Half* _address, float value) {
+   volatile c10::Half* address = const_cast<volatile c10::Half*>(_address);
+
+   size_t offset = (size_t)address & 0x2;
+   volatile uint16_t* address_as_short =
+       reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
+   volatile uint32_t* address_as_uint =
+       reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+   bool is_32bit_aligned = offset == 0;
+
+   uint32_t current = address_as_uint[0];
+   uint32_t expected;
+
+   do {
+     expected = current;
+     c10::Half current_half(address_as_short[0], c10::Half::from_bits());
+     c10::Half next_half = current_half + value;
+     uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_half.x
+                                      : (current & 0x0000ffff) | (next_half.x << 16);
+     current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
+   } while (current != expected);
+
+ }
+
+ } // namespace motif
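The 16-bit specializations above emulate a bf16/fp16 atomicAdd with a compare-and-swap loop on the enclosing aligned 32-bit word, rewriting only the half that holds the target value. The Python sketch below is illustrative only (it truncates rather than rounds when converting a float to bf16 bits) and demonstrates the same word-packing expression used in the CAS loop:

```python
# Illustrative sketch (not part of the commit) of the 16-bit-in-32-bit packing
# used by atomic_add<c10::BFloat16, float> above.
import struct


def bf16_bits_to_float(bits: int) -> float:
    # A bf16 value is the top 16 bits of an IEEE-754 float32.
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]


def float_to_bf16_bits(value: float) -> int:
    # Truncating conversion (the kernel uses c10's rounding conversion).
    return struct.unpack("<I", struct.pack("<f", value))[0] >> 16


def pack_half(word: int, new_bits: int, is_32bit_aligned: bool) -> int:
    # Mirrors: is_32bit_aligned ? (current & 0xffff0000) | next.x
    #                           : (current & 0x0000ffff) | (next.x << 16)
    if is_32bit_aligned:
        return (word & 0xFFFF0000) | new_bits
    return (word & 0x0000FFFF) | (new_bits << 16)


# Example: add 0.5 to the bf16 stored in the low half of a 32-bit word.
word = pack_half(0xDEAD0000, float_to_bf16_bits(1.0), is_32bit_aligned=True)
old = bf16_bits_to_float(word & 0xFFFF)
word = pack_half(word, float_to_bf16_bits(old + 0.5), is_32bit_aligned=True)
print(hex(word), bf16_bits_to_float(word & 0xFFFF))  # high half 0xDEAD is untouched
```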
activation/cuda_compat.h ADDED
@@ -0,0 +1,18 @@
+ #pragma once
+
+ #ifdef USE_ROCM
+   #include <hip/hip_runtime.h>
+ #endif
+
+ #ifndef USE_ROCM
+   #define WARP_SIZE 32
+ #else
+   #define WARP_SIZE warpSize
+ #endif
+
+ #ifndef USE_ROCM
+   #define VLLM_LDG(arg) __ldg(arg)
+ #else
+   #define VLLM_LDG(arg) *(arg)
+ #endif
+
activation/dispatch_utils.h ADDED
@@ -0,0 +1,15 @@
+ /*
+  * Adapted from
+  * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+  */
+ #pragma once
+
+ #include <torch/all.h>
+
+ #define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...)           \
+   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
+   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
+   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+ #define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)    \
+   AT_DISPATCH_SWITCH(TYPE, NAME, MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
build.toml ADDED
@@ -0,0 +1,20 @@
+ [general]
+ name = "activation"
+
+ [torch]
+ src = [
+   "torch-ext/torch_binding.cpp",
+   "torch-ext/torch_binding.h"
+ ]
+
+ [kernel.activation]
+ language = "cuda-hipify"
+ rocm-archs = [ "gfx90a" ]
+ src = [
+   "activation/activation_kernels.cu",
+   "activation/cuda_compat.h",
+   "activation/dispatch_utils.h",
+   "activation/assert_utils.h",
+   "activation/atomic_utils.h",
+ ]
+ depends = [ "torch" ]
flake.lock ADDED
@@ -0,0 +1,169 @@
+ {
+   "nodes": {
+     "flake-compat": {
+       "locked": {
+         "lastModified": 1733328505,
+         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-compat_2": {
+       "locked": {
+         "lastModified": 1733328505,
+         "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+         "type": "github"
+       },
+       "original": {
+         "owner": "edolstra",
+         "repo": "flake-compat",
+         "type": "github"
+       }
+     },
+     "flake-utils": {
+       "inputs": {
+         "systems": "systems"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "flake-utils_2": {
+       "inputs": {
+         "systems": "systems_2"
+       },
+       "locked": {
+         "lastModified": 1731533236,
+         "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "numtide",
+         "repo": "flake-utils",
+         "type": "github"
+       }
+     },
+     "hf-nix": {
+       "inputs": {
+         "flake-compat": "flake-compat_2",
+         "flake-utils": "flake-utils_2",
+         "nixpkgs": "nixpkgs"
+       },
+       "locked": {
+         "lastModified": 1747919133,
+         "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
+         "owner": "huggingface",
+         "repo": "hf-nix",
+         "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
+         "type": "github"
+       },
+       "original": {
+         "owner": "huggingface",
+         "repo": "hf-nix",
+         "type": "github"
+       }
+     },
+     "kernel-builder": {
+       "inputs": {
+         "flake-compat": "flake-compat",
+         "flake-utils": "flake-utils",
+         "hf-nix": "hf-nix",
+         "nixpkgs": [
+           "kernel-builder",
+           "hf-nix",
+           "nixpkgs"
+         ]
+       },
+       "locked": {
+         "lastModified": 1747925434,
+         "narHash": "sha256-yjtdRMyPIFcSF1PkDwU5Rl0bmIpJ5joad5VOt/+1ZLY=",
+         "ref": "refs/heads/main",
+         "rev": "fd0376ff1fec423c91589075fb9042767558c635",
+         "shallow": true,
+         "type": "git",
+         "url": "file:///home/nixuser/kernel-builder"
+       },
+       "original": {
+         "shallow": true,
+         "type": "git",
+         "url": "file:///home/nixuser/kernel-builder"
+       }
+     },
+     "nixpkgs": {
+       "locked": {
+         "lastModified": 1747820358,
+         "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
+         "owner": "danieldk",
+         "repo": "nixpkgs",
+         "rev": "d3c1681180717528068082103bf323147de6ab0b",
+         "type": "github"
+       },
+       "original": {
+         "owner": "danieldk",
+         "ref": "cudatoolkit-12.9-kernel-builder",
+         "repo": "nixpkgs",
+         "type": "github"
+       }
+     },
+     "root": {
+       "inputs": {
+         "kernel-builder": "kernel-builder"
+       }
+     },
+     "systems": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     },
+     "systems_2": {
+       "locked": {
+         "lastModified": 1681028828,
+         "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+         "owner": "nix-systems",
+         "repo": "default",
+         "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+         "type": "github"
+       },
+       "original": {
+         "owner": "nix-systems",
+         "repo": "default",
+         "type": "github"
+       }
+     }
+   },
+   "root": "root",
+   "version": 7
+ }
flake.nix ADDED
@@ -0,0 +1,13 @@
+ {
+   description = "Flake for Torch kernel extension";
+
+   inputs = {
+     kernel-builder.url = "/home/nixuser/kernel-builder";
+   };
+
+   outputs = { self, kernel-builder, }:
+     kernel-builder.lib.genFlakeOutputs {
+       path = ./.;
+       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+     };
+ }
tests/__init__.py ADDED
File without changes
tests/kernels/__init__.py ADDED
File without changes
tests/kernels/allclose_default.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+
+ # Reference default values of atol and rtol are from
+ # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
+ default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
+ default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
+
+
+ def get_default_atol(output) -> float:
+     return default_atol[output.dtype]
+
+
+ def get_default_rtol(output) -> float:
+     return default_rtol[output.dtype]
tests/kernels/test_activation.py ADDED
@@ -0,0 +1,91 @@
+ import random
+
+ import pytest
+ import torch
+
+ import activation
+
+ from .utils import assert_close, opcheck
+
+ DTYPES = [torch.float, torch.bfloat16, torch.half]
+ # NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+ # D = [512, 13824]  # Arbitrary values for testing
+ NUM_TOKENS = [7, 13]  # Arbitrary values for testing
+ D = [513]  # Arbitrary values for testing
+ SEEDS = [0]
+ CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+
+
+ def norm(x, eps: float) -> torch.Tensor:
+     return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+ def poly_norm(
+     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
+ ) -> torch.Tensor:
+     x = x.float()
+     return (
+         weight[0] * norm(x**3, eps)
+         + weight[1] * norm(x**2, eps)
+         + weight[2] * norm(x, eps)
+         + bias
+     ).to(weight.dtype)
+
+
+ @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+ @pytest.mark.parametrize("d", D)
+ @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize("seed", SEEDS)
+ @pytest.mark.parametrize("device", CUDA_DEVICES)
+ def test_poly_norm(
+     num_tokens: int,
+     d: int,
+     dtype: torch.dtype,
+     seed: int,
+     device: str,
+ ) -> None:
+     random.seed(seed)
+     torch.manual_seed(seed)
+     torch.set_default_device(device)
+
+     x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
+     weight = torch.randn(3, dtype=dtype, requires_grad=True)
+     bias = torch.randn(1, dtype=dtype, requires_grad=True)
+     eps = 1e-05
+
+     x.retain_grad()
+     weight.retain_grad()
+     bias.retain_grad()
+     # To separate gradient computation, clone the inputs
+
+     x_ref = x.detach().clone().requires_grad_(True)
+     weight_ref = weight.detach().clone().requires_grad_(True)
+     bias_ref = bias.detach().clone().requires_grad_(True)
+
+     torch_fn = poly_norm
+     op = activation.ops.poly_norm
+     fn = activation.poly_norm
+     layer = activation.layers.PolyNorm(eps)
+     layer.weight = torch.nn.Parameter(weight)
+     layer.bias = torch.nn.Parameter(bias)
+
+     out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
+     opcheck(op, (out, x, weight, bias, eps))
+
+     out = fn(x, weight, bias, eps)
+     mod_out = layer(x)
+     ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)
+
+     assert_close(out, ref_out)
+     assert_close(mod_out, out, atol=0.0, rtol=0.0)
+
+     # test backward pass
+     out_grad = torch.randn_like(out)
+     out_grad = out_grad / out_grad.norm()
+
+     ref_out.backward(out_grad)
+     mod_out.backward(out_grad)
+
+     assert_close(x.grad, x_ref.grad)
+     assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
+     assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
tests/kernels/utils.py ADDED
@@ -0,0 +1,82 @@
+ """Kernel test utils"""
+
+ import unittest
+ from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+ import torch
+ from torch._prims_common import TensorLikeType
+
+ from .allclose_default import get_default_atol, get_default_rtol
+
+ # For now, disable "test_aot_dispatch_dynamic" since there are some
+ # bugs related to this test in PyTorch 2.4.
+ DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+     "test_schema",
+     "test_autograd_registration",
+     "test_faketensor",
+ )
+
+ ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+     "test_schema",
+     "test_autograd_registration",
+     "test_faketensor",
+     "test_aot_dispatch_dynamic",
+ )
+
+
+ def assert_close(
+     a: TensorLikeType,
+     b: TensorLikeType,
+     atol: float | None = None,
+     rtol: float | None = None,
+ ) -> None:
+     atol = atol if atol is not None else get_default_atol(a)
+     rtol = rtol if rtol is not None else get_default_rtol(a)
+     torch.testing.assert_close(a, b, atol=atol, rtol=rtol)
+
+
+ # Copied/modified from torch._refs.__init__.py
+ def fp8_allclose(
+     a: TensorLikeType,
+     b: TensorLikeType,
+     rtol: float = 1e-05,
+     atol: float = 1e-08,
+     equal_nan: bool = False,
+ ) -> bool:
+     """
+     Reference implementation of torch.allclose
+     """
+     torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
+
+     return bool(
+         torch.all(
+             torch.isclose(
+                 a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
+             )
+         ).item()
+     )
+
+
+ # A special version of op check that has a restricted default set of test_utils
+ # and a patched version of allclose that supports fp8 types.
+ def opcheck(
+     op: Union[
+         torch._ops.OpOverload,
+         torch._ops.OpOverloadPacket,
+         torch._library.custom_ops.CustomOpDef,
+     ],
+     args: Tuple[Any, ...],
+     kwargs: Optional[Dict[str, Any]] = None,
+     *,
+     test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
+     raise_exception: bool = True,
+     cond: bool = True,
+ ) -> Dict[str, str]:
+     with unittest.mock.patch("torch.allclose", new=fp8_allclose):
+         return (
+             torch.library.opcheck(
+                 op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
+             )
+             if cond
+             else {}
+         )
tests/test.py ADDED
@@ -0,0 +1,38 @@
+ import activation
+ import torch
+
+
+ def norm(x, eps: float) -> torch.Tensor:
+     return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+
+ def poly_norm(
+     x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
+ ) -> torch.Tensor:
+     x = x.float()
+     return (
+         weight[0] * norm(x**3, eps)
+         + weight[1] * norm(x**2, eps)
+         + weight[2] * norm(x, eps)
+         + bias
+     ).to(weight.dtype)
+
+
+ dtype = torch.bfloat16
+ torch.set_default_device("cuda:0")
+ a = torch.randn(3, 3, dtype=dtype, requires_grad=True)
+ w = torch.randn(3, dtype=dtype, requires_grad=True)
+ b = torch.randn(1, dtype=dtype, requires_grad=True)
+
+ a.retain_grad()
+ w.retain_grad()
+ b.retain_grad()
+
+ out = activation.poly_norm(a, w, b, 1e-6)
+ # out = poly_norm(a, w, b, 1e-6)
+
+ out.backward(torch.ones_like(out))
+
+ print(a.grad)
+ print(w.grad)
+ print(b.grad)
torch-ext/activation/__init__.py ADDED
@@ -0,0 +1,21 @@
+ import torch
+
+ from . import layers
+ from ._ops import ops
+ from .poly_norm import PolyNormFunction
+
+
+ def poly_norm(
+     x: torch.Tensor,
+     weight: torch.Tensor,
+     bias: torch.Tensor,
+     eps: float = 1e-6,
+ ) -> torch.Tensor:
+     return PolyNormFunction.apply(x, weight, bias, eps)
+
+
+ __all__ = [
+     "poly_norm",
+     "layers",
+     "ops",
+ ]
torch-ext/activation/layers.py ADDED
@@ -0,0 +1,18 @@
+ import torch
+ import torch.nn as nn
+
+ from .poly_norm import PolyNormFunction
+
+
+ class PolyNorm(nn.Module):
+     def __init__(self, eps):
+         super().__init__()
+         self.weight = torch.nn.Parameter(torch.ones(3) / 3)
+         self.bias = torch.nn.Parameter(torch.zeros(1))
+         self.eps = eps
+
+     def forward(
+         self,
+         x: torch.Tensor,
+     ):
+         return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
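A usage sketch for the module (illustrative, not part of the commit; the batch size, hidden size, dtype, and device are arbitrary assumptions):

```python
# Illustrative sketch: drop-in use of the PolyNorm module defined above.
import torch
import activation

layer = activation.layers.PolyNorm(eps=1e-6).to("cuda", torch.bfloat16)
x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

y = layer(x)        # forward runs the fused CUDA kernel
y.sum().backward()  # backward fills x.grad, layer.weight.grad, layer.bias.grad
```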
torch-ext/activation/poly_norm.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+
+ from ._ops import ops
+
+
+ # Inherit from Function
+ class PolyNormFunction(torch.autograd.Function):
+     # Note that forward, setup_context, and backward are @staticmethods
+     @staticmethod
+     def forward(input, weight, bias, eps):
+         output = torch.empty_like(input)
+         ops.poly_norm(output, input, weight, bias, eps)
+         return output
+
+     @staticmethod
+     # inputs is a Tuple of all of the inputs passed to forward.
+     # output is the output of the forward().
+     def setup_context(ctx, inputs, output):
+         input, weight, bias, eps = inputs
+         ctx.save_for_backward(input, weight)
+         ctx.eps = eps
+
+     # This function has only a single output, so it gets only one gradient
+     @staticmethod
+     def backward(ctx, output_grad):
+         input, weight = ctx.saved_tensors
+         eps = ctx.eps
+
+         input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
+         weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+         bias_grad = (
+             torch.empty(1, dtype=weight.dtype, device=weight.device)
+             if ctx.needs_input_grad[2]
+             else None
+         )
+
+         ops.poly_norm_backward(
+             input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
+         )
+
+         return input_grad, weight_grad, bias_grad, None
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,14 @@
+ #include <torch/library.h>
+
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
+   // Activation ops
+   ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float eps) -> ()");
+   ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! bias_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
+   ops.impl("poly_norm", torch::kCUDA, &poly_norm);
+   ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
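Both ops are registered as out-variants returning `()`, so the caller allocates the outputs and the op writes into them, which is exactly what the autograd wrapper in torch-ext/activation/poly_norm.py does. A minimal sketch of a direct call (illustrative; tensor shapes and dtype are arbitrary):

```python
# Sketch mirroring torch-ext/activation/poly_norm.py and the opcheck call in the tests.
import torch
from activation._ops import ops

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
weight = torch.randn(3, device="cuda", dtype=torch.float16)
bias = torch.zeros(1, device="cuda", dtype=torch.float16)

out = torch.empty_like(x)
ops.poly_norm(out, x, weight, bias, 1e-6)  # matches "poly_norm(Tensor! out, ...) -> ()"
```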
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,6 @@
+ #pragma once
+
+ #include <torch/torch.h>
+
+ void poly_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weights, torch::Tensor &bias, double eps);
+ void poly_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, torch::Tensor& bias_grad, torch::Tensor& output_grad, torch::Tensor& input, torch::Tensor& weight, double eps);