Commit 44e9845
Parent(s): 7a7d761

feat(poly-norm): Add PolyNorm
Files changed:
- .gitignore +193 -0
- README.md +4 -0
- activation/activation_kernels.cu +267 -0
- activation/assert_utils.h +22 -0
- activation/atomic_utils.h +73 -0
- activation/cuda_compat.h +18 -0
- activation/dispatch_utils.h +15 -0
- build.toml +20 -0
- flake.lock +169 -0
- flake.nix +13 -0
- tests/__init__.py +0 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/allclose_default.py +14 -0
- tests/kernels/test_activation.py +91 -0
- tests/kernels/utils.py +82 -0
- tests/test.py +38 -0
- torch-ext/activation/__init__.py +21 -0
- torch-ext/activation/layers.py +18 -0
- torch-ext/activation/poly_norm.py +41 -0
- torch-ext/torch_binding.cpp +14 -0
- torch-ext/torch_binding.h +6 -0
.gitignore
ADDED
@@ -0,0 +1,193 @@
# Created by https://www.toptal.com/developers/gitignore/api/vim,python
# Edit at https://www.toptal.com/developers/gitignore?templates=vim,python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.Python
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg  # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
*~
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

# End of https://www.toptal.com/developers/gitignore/api/vim,python
README.md
ADDED
@@ -0,0 +1,4 @@
---
tags:
- kernel
---
activation/activation_kernels.cu
ADDED
@@ -0,0 +1,267 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "assert_utils.h"
#include "atomic_utils.h"

namespace motif {

template <typename acc_t, int BLOCK_SIZE>
__device__ acc_t _block_reduce_sum(volatile acc_t* shared, const float val, const int d) {
  // TODO: Optimize with warp-level primitives
  shared[threadIdx.x] = threadIdx.x < d ? val : 0.0f;
  __syncthreads();
  for (int stride = BLOCK_SIZE / 2; stride > 0; stride /= 2) {
    if (threadIdx.x < stride) {
      shared[threadIdx.x] += shared[threadIdx.x + stride];
    }
    __syncthreads();
  }

  return shared[0];
}

template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
__global__ void poly_norm_kernel(
    scalar_t* __restrict__ out,           // [..., d]
    const scalar_t* __restrict__ input,   // [..., d]
    const scalar_t* __restrict__ weight,  // [3]
    const scalar_t* __restrict__ bias,    // [1]
    const float eps,
    const int d
) {
  const int64_t token_idx = blockIdx.x;

  acc_t sum = 0.0f;
  acc_t sum_square = 0.0f;
  acc_t sum_cube = 0.0f;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t x = input[token_idx * d + idx];
    sum += pow(x, 2.0f);
    sum_square += pow(x, 4.0f);
    sum_cube += pow(x, 6.0f);
  }

  __shared__ acc_t shared[BLOCK_SIZE];

  acc_t mean = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum, d) / d;
  acc_t mean_square = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
  acc_t mean_cube = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_cube, d) / d;

  acc_t w0 = weight[0];
  acc_t w1 = weight[1];
  acc_t w2 = weight[2];
  acc_t b = bias[0];

  acc_t divisor = sqrt(mean + eps);
  acc_t divisor_square = sqrt(mean_square + eps);
  acc_t divisor_cube = sqrt(mean_cube + eps);

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t x = input[token_idx * d + idx];
    acc_t x_square = pow(x, 2.0f);
    acc_t x_cube = pow(x, 3.0f);
    out[token_idx * d + idx] = w2 * x / divisor +
                               w1 * x_square / divisor_square +
                               w0 * x_cube / divisor_cube + b;
  }
}

template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
__global__ void poly_norm_backward_kernel(
    scalar_t* __restrict__ input_grad,         // [..., d]
    scalar_t* __restrict__ weight_grad,        // [3]
    scalar_t* __restrict__ bias_grad,          // [1]
    const scalar_t* __restrict__ output_grad,  // [..., d]
    const scalar_t* __restrict__ input,        // [..., d]
    const scalar_t* __restrict__ weight,       // [3]
    const float eps,
    const int d
) {
  const int64_t token_idx = blockIdx.x;

  acc_t w0 = weight[0];
  acc_t w1 = weight[1];
  acc_t w2 = weight[2];

  acc_t sum_2 = 0.0f;
  acc_t sum_4 = 0.0f;
  acc_t sum_6 = 0.0f;

  acc_t sum_dx_1 = 0.0f;
  acc_t sum_dx_2 = 0.0f;
  acc_t sum_dx_3 = 0.0f;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t dy = output_grad[token_idx * d + idx];

    acc_t x_1 = input[token_idx * d + idx];
    acc_t x_2 = x_1 * x_1;
    acc_t x_3 = x_2 * x_1;
    acc_t x_4 = x_3 * x_1;
    acc_t x_6 = x_4 * x_2;

    sum_2 += x_2;
    sum_4 += x_4;
    sum_6 += x_6;

    sum_dx_1 += w2 * dy * x_1;
    sum_dx_2 += w1 * dy * x_2;
    sum_dx_3 += w0 * dy * x_3;
  }

  __shared__ acc_t shared[BLOCK_SIZE];

  acc_t mean_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_2, d) / d + eps;
  acc_t mean_4 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_4, d) / d + eps;
  acc_t mean_6 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_6, d) / d + eps;

  sum_dx_1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_1, d);
  sum_dx_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_2, d);
  sum_dx_3 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_3, d);

  acc_t sq_mean_2 = rsqrtf(mean_2) * mean_2;
  acc_t sq_mean_4 = rsqrtf(mean_4) * mean_4;
  acc_t sq_mean_6 = rsqrtf(mean_6) * mean_6;

  acc_t denom_2 = mean_2 * sq_mean_2;
  acc_t denom_4 = mean_4 * sq_mean_4;
  acc_t denom_6 = mean_6 * sq_mean_6;

  acc_t sum_dw0 = 0;
  acc_t sum_dw1 = 0;
  acc_t sum_dw2 = 0;
  acc_t sum_db = 0;

  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    acc_t dy = output_grad[token_idx * d + idx];
    acc_t x_1 = input[token_idx * d + idx];
    acc_t x_2 = x_1 * x_1;
    acc_t x_3 = x_2 * x_1;

    acc_t _dx_1 = w2 * dy;
    acc_t _dx_2 = w1 * dy;
    acc_t _dx_3 = w0 * dy;

    acc_t dx_3 =
        3 * x_2 * (_dx_3 / sq_mean_6 - x_3 * sum_dx_3 / (d * denom_6));
    acc_t dx_2 =
        2 * x_1 * (_dx_2 / sq_mean_4 - x_2 * sum_dx_2 / (d * denom_4));
    acc_t dx_1 =
        _dx_1 / sq_mean_2 - x_1 * sum_dx_1 / (d * denom_2);

    if (input_grad) {
      input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3;
    }

    sum_dw0 += dy * (x_3 / sq_mean_6);
    sum_dw1 += dy * (x_2 / sq_mean_4);
    sum_dw2 += dy * (x_1 / sq_mean_2);
    sum_db += dy;
  }

  if (weight_grad) {
    sum_dw0 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw0, d);
    sum_dw1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw1, d);
    sum_dw2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw2, d);

    if (threadIdx.x == 0) {
      atomic_add(&weight_grad[0], sum_dw0);
      atomic_add(&weight_grad[1], sum_dw1);
      atomic_add(&weight_grad[2], sum_dw2);
    }
  }

  if (bias_grad) {
    sum_db = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_db, d);
    if (threadIdx.x == 0) {
      atomic_add(&bias_grad[0], sum_db);
    }
  }
}

} // namespace motif


void poly_norm(torch::Tensor& out,     // [..., d]
               torch::Tensor& input,   // [..., d]
               torch::Tensor& weight,  // [3]
               torch::Tensor& bias,    // [1]
               double eps)
{
  AssertTensorShapeEqual(input, out, "input", "out");
  AssertTensorNotNull(weight, "weight");
  AssertTensorNotNull(bias, "bias");
  // TODO shape check

  constexpr int BLOCK_SIZE = 256;

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  dim3 block(BLOCK_SIZE);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  MOTIF_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "poly_norm_kernel", [&] {
        motif::poly_norm_kernel<scalar_t, float, BLOCK_SIZE>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<scalar_t>(),
                input.data_ptr<scalar_t>(),
                weight.data_ptr<scalar_t>(),
                bias.data_ptr<scalar_t>(), eps, d);
      }
  );
}

void poly_norm_backward(
    torch::Tensor& input_grad,   // [..., d]
    torch::Tensor& weight_grad,  // [3]
    torch::Tensor& bias_grad,    // [1]
    torch::Tensor& output_grad,  // [..., d]
    torch::Tensor& input,        // [..., d]
    torch::Tensor& weight,       // [3]
    double eps) {
  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
  AssertTensorNotNull(weight, "weight");
  // TODO shape check
  // weight_grad, bias_grad and input_grad can be nullable

  constexpr int BLOCK_SIZE = 256;

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / input.size(-1);
  dim3 grid(num_tokens);
  dim3 block(BLOCK_SIZE);

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (weight_grad.defined())
    cudaMemsetAsync(weight_grad.data_ptr(), 0, weight_grad.numel() * weight_grad.element_size(), stream);
  if (bias_grad.defined()) {
    cudaMemsetAsync(bias_grad.data_ptr(), 0, bias_grad.numel() * bias_grad.element_size(), stream);
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  MOTIF_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "poly_norm_backward_kernel", [&] {
        motif::poly_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
            <<<grid, block, 0, stream>>>(
                input_grad.data_ptr<scalar_t>(),
                weight_grad.data_ptr<scalar_t>(),
                bias_grad.data_ptr<scalar_t>(),
                output_grad.data_ptr<scalar_t>(),
                input.data_ptr<scalar_t>(),
                weight.data_ptr<scalar_t>(),
                eps, d);
      }
  );
}
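As a reading aid (not part of the commit): per token of width d, with x̄ᵏ the mean of x_jᵏ over the last dimension, poly_norm_kernel computes

\mathrm{PolyNorm}(x)_i =
    w_2 \frac{x_i}{\sqrt{\overline{x^2} + \varepsilon}}
  + w_1 \frac{x_i^2}{\sqrt{\overline{x^4} + \varepsilon}}
  + w_0 \frac{x_i^3}{\sqrt{\overline{x^6} + \varepsilon}}
  + b,
\qquad \overline{x^k} = \frac{1}{d}\sum_{j=1}^{d} x_j^k

which is exactly weight[0]*norm(x**3) + weight[1]*norm(x**2) + weight[2]*norm(x) + bias in the Python reference used by tests/kernels/test_activation.py; note that weight[0] pairs with the cubic term and weight[2] with the linear one.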
activation/assert_utils.h
ADDED
@@ -0,0 +1,22 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>

inline void AssertTensorNotNull(const torch::Tensor &tensor, const std::string &name) {
  TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
}

inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a, const torch::Tensor &tensor_b,
                                   const std::string &name_a, const std::string &name_b) {

  AssertTensorNotNull(tensor_a, name_a);
  AssertTensorNotNull(tensor_b, name_b);

  auto tensor_shape_a = tensor_a.sizes();
  auto tensor_shape_b = tensor_b.sizes();

  TORCH_INTERNAL_ASSERT(tensor_shape_a.equals(tensor_shape_b),
                        name_a, " tensor shape should be equal to ", name_b, " tensor shape. (actual: ",
                        tensor_shape_a, ", expected: ", tensor_shape_b, ")");
}
activation/atomic_utils.h
ADDED
@@ -0,0 +1,73 @@
#pragma once

#include <cuda.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

namespace motif {
template<typename scalar_t, typename acc_t>
__device__ inline void atomic_add(scalar_t* address, acc_t value) {
  // TODO: change assert to a static_assert if possible
  assert(false && "Unsupported type for atomic_add");
}

template<>
__device__ inline void atomic_add<float, float>(float* address, float value) {
  atomicAdd(address, value);
}

template<>
__device__ inline void atomic_add<double, double>(double* address, double value) {
  atomicAdd(address, value);
}

template<>
__device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16* _address, float value) {
  volatile c10::BFloat16* address = const_cast<volatile c10::BFloat16*>(_address);

  size_t offset = (size_t)address & 0x2;
  volatile uint16_t* address_as_short =
      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
  volatile uint32_t* address_as_uint =
      reinterpret_cast<volatile uint32_t*>(reinterpret_cast<volatile char*>(address) - offset);
  bool is_32bit_aligned = offset == 0;

  uint32_t current = address_as_uint[0];
  uint32_t expected;

  do {
    expected = current;
    c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
    c10::BFloat16 next_bf16 = current_bf16 + value;
    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_bf16.x
                                     : (current & 0x0000ffff) | (next_bf16.x << 16);
    current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
  } while (current != expected);
}

template<>
__device__ inline void atomic_add<c10::Half, float>(c10::Half* _address, float value) {
  volatile c10::Half* address = const_cast<volatile c10::Half*>(_address);

  size_t offset = (size_t)address & 0x2;
  volatile uint16_t* address_as_short =
      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
  volatile uint32_t* address_as_uint =
      reinterpret_cast<volatile uint32_t*>(reinterpret_cast<volatile char*>(address) - offset);
  bool is_32bit_aligned = offset == 0;

  uint32_t current = address_as_uint[0];
  uint32_t expected;

  do {
    expected = current;
    c10::Half current_half(address_as_short[0], c10::Half::from_bits());
    c10::Half next_half = current_half + value;
    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_half.x
                                     : (current & 0x0000ffff) | (next_half.x << 16);
    current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
  } while (current != expected);

}

} // namespace motif
activation/cuda_compat.h
ADDED
@@ -0,0 +1,18 @@
#pragma once

#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif

#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif
activation/dispatch_utils.h
ADDED
@@ -0,0 +1,15 @@
/*
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/all.h>

#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...)         \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
build.toml
ADDED
@@ -0,0 +1,20 @@
[general]
name = "activation"

[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h"
]

[kernel.activation]
language = "cuda-hipify"
rocm-archs = [ "gfx90a" ]
src = [
  "activation/activation_kernels.cu",
  "activation/cuda_compat.h",
  "activation/dispatch_utils.h",
  "activation/assert_utils.h",
  "activation/atomic_utils.h",
]
depends = [ "torch" ]
flake.lock
ADDED
@@ -0,0 +1,169 @@
{
  "nodes": {
    "flake-compat": {
      "locked": {
        "lastModified": 1733328505,
        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-compat_2": {
      "locked": {
        "lastModified": 1733328505,
        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
        "owner": "edolstra",
        "repo": "flake-compat",
        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
        "type": "github"
      },
      "original": {
        "owner": "edolstra",
        "repo": "flake-compat",
        "type": "github"
      }
    },
    "flake-utils": {
      "inputs": {
        "systems": "systems"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "flake-utils_2": {
      "inputs": {
        "systems": "systems_2"
      },
      "locked": {
        "lastModified": 1731533236,
        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
        "owner": "numtide",
        "repo": "flake-utils",
        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
        "type": "github"
      },
      "original": {
        "owner": "numtide",
        "repo": "flake-utils",
        "type": "github"
      }
    },
    "hf-nix": {
      "inputs": {
        "flake-compat": "flake-compat_2",
        "flake-utils": "flake-utils_2",
        "nixpkgs": "nixpkgs"
      },
      "locked": {
        "lastModified": 1747919133,
        "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
        "owner": "huggingface",
        "repo": "hf-nix",
        "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
        "type": "github"
      },
      "original": {
        "owner": "huggingface",
        "repo": "hf-nix",
        "type": "github"
      }
    },
    "kernel-builder": {
      "inputs": {
        "flake-compat": "flake-compat",
        "flake-utils": "flake-utils",
        "hf-nix": "hf-nix",
        "nixpkgs": [
          "kernel-builder",
          "hf-nix",
          "nixpkgs"
        ]
      },
      "locked": {
        "lastModified": 1747925434,
        "narHash": "sha256-yjtdRMyPIFcSF1PkDwU5Rl0bmIpJ5joad5VOt/+1ZLY=",
        "ref": "refs/heads/main",
        "rev": "fd0376ff1fec423c91589075fb9042767558c635",
        "shallow": true,
        "type": "git",
        "url": "file:///home/nixuser/kernel-builder"
      },
      "original": {
        "shallow": true,
        "type": "git",
        "url": "file:///home/nixuser/kernel-builder"
      }
    },
    "nixpkgs": {
      "locked": {
        "lastModified": 1747820358,
        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
        "owner": "danieldk",
        "repo": "nixpkgs",
        "rev": "d3c1681180717528068082103bf323147de6ab0b",
        "type": "github"
      },
      "original": {
        "owner": "danieldk",
        "ref": "cudatoolkit-12.9-kernel-builder",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "kernel-builder": "kernel-builder"
      }
    },
    "systems": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    },
    "systems_2": {
      "locked": {
        "lastModified": 1681028828,
        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
        "owner": "nix-systems",
        "repo": "default",
        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
        "type": "github"
      },
      "original": {
        "owner": "nix-systems",
        "repo": "default",
        "type": "github"
      }
    }
  },
  "root": "root",
  "version": 7
}
flake.nix
ADDED
@@ -0,0 +1,13 @@
{
  description = "Flake for Torch kernel extension";

  inputs = {
    kernel-builder.url = "/home/nixuser/kernel-builder";
  };

  outputs = { self, kernel-builder, }:
    kernel-builder.lib.genFlakeOutputs {
      path = ./.;
      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
    };
}
tests/__init__.py
ADDED
File without changes
tests/kernels/__init__.py
ADDED
File without changes
tests/kernels/allclose_default.py
ADDED
@@ -0,0 +1,14 @@
import torch

# Reference default values of atol and rtol are from
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}


def get_default_atol(output) -> float:
    return default_atol[output.dtype]


def get_default_rtol(output) -> float:
    return default_rtol[output.dtype]
tests/kernels/test_activation.py
ADDED
@@ -0,0 +1,91 @@
import random

import pytest
import torch

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
# NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
# D = [512, 13824]  # Arbitrary values for testing
NUM_TOKENS = [7, 13]  # Arbitrary values for testing
D = [513]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]


def norm(x, eps: float) -> torch.Tensor:
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
    x = x.float()
    return (
        weight[0] * norm(x**3, eps)
        + weight[1] * norm(x**2, eps)
        + weight[2] * norm(x, eps)
        + bias
    ).to(weight.dtype)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_poly_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)

    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    weight.retain_grad()
    bias.retain_grad()
    # To separate gradient computation, clone the inputs

    x_ref = x.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    torch_fn = poly_norm
    op = activation.ops.poly_norm
    fn = activation.poly_norm
    layer = activation.layers.PolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    opcheck(op, (out, x, weight, bias, eps))

    out = fn(x, weight, bias, eps)
    mod_out = layer(x)
    ref_out = torch_fn(x_ref, weight_ref, bias_ref, eps)

    assert_close(out, ref_out)
    assert_close(mod_out, out, atol=0.0, rtol=0.0)

    # test backward pass
    out_grad = torch.randn_like(out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad)
    mod_out.backward(out_grad)

    assert_close(x.grad, x_ref.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
tests/kernels/utils.py
ADDED
@@ -0,0 +1,82 @@
"""Kernel test utils"""

import unittest
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
from torch._prims_common import TensorLikeType

from .allclose_default import get_default_atol, get_default_rtol

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


def assert_close(
    a: TensorLikeType,
    b: TensorLikeType,
    atol: float | None = None,
    rtol: float | None = None,
) -> None:
    atol = atol if atol is not None else get_default_atol(a)
    rtol = rtol if rtol is not None else get_default_rtol(a)
    torch.testing.assert_close(a, b, atol=atol, rtol=rtol)


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose
    """
    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)

    return bool(
        torch.all(
            torch.isclose(
                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        ).item()
    )


# A special version of op check that has a restricted default set of test_utils
# and a patched version of allclose that supports fp8 types.
def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        torch._library.custom_ops.CustomOpDef,
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True,
) -> Dict[str, str]:
    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
        return (
            torch.library.opcheck(
                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
            )
            if cond
            else {}
        )
tests/test.py
ADDED
@@ -0,0 +1,38 @@
import activation
import torch


def norm(x, eps: float) -> torch.Tensor:
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def poly_norm(
    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
) -> torch.Tensor:
    x = x.float()
    return (
        weight[0] * norm(x**3, eps)
        + weight[1] * norm(x**2, eps)
        + weight[2] * norm(x, eps)
        + bias
    ).to(weight.dtype)


dtype = torch.bfloat16
torch.set_default_device("cuda:0")
a = torch.randn(3, 3, dtype=dtype, requires_grad=True)
w = torch.randn(3, dtype=dtype, requires_grad=True)
b = torch.randn(1, dtype=dtype, requires_grad=True)

a.retain_grad()
w.retain_grad()
b.retain_grad()

out = activation.poly_norm(a, w, b, 1e-6)
# out = poly_norm(a, w, b, 1e-6)

out.backward(torch.ones_like(out))

print(a.grad)
print(w.grad)
print(b.grad)
torch-ext/activation/__init__.py
ADDED
@@ -0,0 +1,21 @@
import torch

from . import layers
from ._ops import ops
from .poly_norm import PolyNormFunction


def poly_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    return PolyNormFunction.apply(x, weight, bias, eps)


__all__ = [
    "poly_norm",
    "layers",
    "ops",
]
torch-ext/activation/layers.py
ADDED
@@ -0,0 +1,18 @@
import torch
import torch.nn as nn

from .poly_norm import PolyNormFunction


class PolyNorm(nn.Module):
    def __init__(self, eps):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1))
        self.eps = eps

    def forward(
        self,
        x: torch.Tensor,
    ):
        return PolyNormFunction.apply(x, self.weight, self.bias, self.eps)
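A minimal usage sketch of the PolyNorm module above (not part of the commit; it assumes the built extension is importable as the activation package and a CUDA device is available, mirroring how the layer is driven in tests/kernels/test_activation.py):

    import torch
    from activation.layers import PolyNorm

    # Create everything on the GPU so the CUDA kernel can run.
    torch.set_default_device("cuda:0")

    layer = PolyNorm(eps=1e-6)                    # weight starts at [1/3, 1/3, 1/3], bias at 0
    x = torch.randn(8, 1024, requires_grad=True)  # float32, same dtype as the parameters
    y = layer(x)                                  # same shape as x; forward runs poly_norm_kernel
    y.mean().backward()                           # gradients flow through poly_norm_backward
    print(x.grad.shape, layer.weight.grad, layer.bias.grad)

Keeping the input and the parameters in the same dtype matters here, since the kernel dispatches on the input's scalar type and reads weight and bias with that same type.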
torch-ext/activation/poly_norm.py
ADDED
@@ -0,0 +1,41 @@
import torch

from ._ops import ops


# Inherit from Function
class PolyNormFunction(torch.autograd.Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias, eps):
        output = torch.empty_like(input)
        ops.poly_norm(output, input, weight, bias, eps)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias, eps = inputs
        ctx.save_for_backward(input, weight)
        ctx.eps = eps

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, output_grad):
        input, weight = ctx.saved_tensors
        eps = ctx.eps

        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
        bias_grad = (
            torch.empty(1, dtype=weight.dtype, device=weight.device)
            if ctx.needs_input_grad[2]
            else None
        )

        ops.poly_norm_backward(
            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
        )

        return input_grad, weight_grad, bias_grad, None
torch-ext/torch_binding.cpp
ADDED
@@ -0,0 +1,14 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Activation ops
  ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float eps) -> ()");
  ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! bias_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
  ops.impl("poly_norm", torch::kCUDA, &poly_norm);
  ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
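For context on the schema above: the registered op is an out-variant that writes into a preallocated tensor. A short sketch of calling it directly (an illustration, not part of the commit; it mirrors the opcheck call in tests/kernels/test_activation.py and assumes the package exposes the op as activation.ops.poly_norm):

    import torch
    import activation

    x = torch.randn(7, 513, dtype=torch.float16, device="cuda")
    weight = torch.randn(3, dtype=torch.float16, device="cuda")
    bias = torch.randn(1, dtype=torch.float16, device="cuda")

    # Low-level call matching
    # "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float eps) -> ()"
    out = torch.empty_like(x)
    activation.ops.poly_norm(out, x, weight, bias, 1e-6)

Most users would go through activation.poly_norm or activation.layers.PolyNorm instead, which wrap this op in PolyNormFunction so autograd works.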
torch-ext/torch_binding.h
ADDED
@@ -0,0 +1,6 @@
#pragma once

#include <torch/torch.h>

void poly_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weights, torch::Tensor &bias, double eps);
void poly_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, torch::Tensor& bias_grad, torch::Tensor& output_grad, torch::Tensor& input, torch::Tensor& weight, double eps);