Commit
·
f517c97
1
Parent(s):
a73a0c0
add readme with precommit hooks and applied pre commit to all files
Browse files- README.md +45 -0
- activation/assert_utils.h +9 -5
- activation/atomic_utils.h +38 -31
- activation/block_reduce.h +3 -2
- activation/dispatch_utils.h +6 -5
- activation/rms_norm.cu +42 -59
- tests/conftest.py +16 -16
- tests/kernels/allclose_default.py +5 -1
- tests/kernels/test_poly_norm.py +7 -10
- tests/kernels/test_poly_norm_perf.py +8 -7
- tests/kernels/test_rms_norm.py +3 -1
- tests/kernels/utils.py +16 -13
- torch-ext/activation/layers.py +2 -0
- torch-ext/activation/poly_norm.py +9 -11
- torch-ext/activation/rms_norm.py +6 -3
- torch-ext/torch_binding.cpp +9 -6
- torch-ext/torch_binding.h +14 -4
README.md
CHANGED
@@ -32,6 +32,7 @@ print(poly_norm(x))
|
|
32 |
|
33 |
- Test cases are from the Motif LLM
|
34 |
- You can reproduce the results with:
|
|
|
35 |
```bash
|
36 |
cd tests
|
37 |
pytest --run-perf --do-plot
|
@@ -39,3 +40,47 @@ pytest --run-perf --do-plot
|
|
39 |
|
40 |

|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
|
33 |
- Test cases are from the Motif LLM
|
34 |
- You can reproduce the results with:
|
35 |
+
|
36 |
```bash
|
37 |
cd tests
|
38 |
pytest --run-perf --do-plot
|
|
|
40 |
|
41 |

|
42 |
|
43 |
+
## Pre-commit Hooks
|
44 |
+
|
45 |
+
This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
|
46 |
+
|
47 |
+
### Setup
|
48 |
+
|
49 |
+
1. Install pre-commit:
|
50 |
+
|
51 |
+
```bash
|
52 |
+
pip install pre-commit
|
53 |
+
```
|
54 |
+
|
55 |
+
2. Install the git hooks:
|
56 |
+
|
57 |
+
```bash
|
58 |
+
pre-commit install
|
59 |
+
```
|
60 |
+
|
61 |
+
Once installed, the configured hooks will run automatically on each commit.
|
62 |
+
|
63 |
+
### Included Hooks
|
64 |
+
|
65 |
+
The following tools are run via pre-commit:
|
66 |
+
|
67 |
+
- **[yapf](https://github.com/google/yapf)** – Python code formatter
|
68 |
+
- **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
|
69 |
+
- **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
|
70 |
+
- **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
|
71 |
+
- **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
|
72 |
+
- **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
|
73 |
+
|
74 |
+
### Usage
|
75 |
+
|
76 |
+
- Run all checks on the entire codebase:
|
77 |
+
|
78 |
+
```bash
|
79 |
+
pre-commit run --all-files
|
80 |
+
```
|
81 |
+
|
82 |
+
- Run a specific hook (example: isort):
|
83 |
+
|
84 |
+
```bash
|
85 |
+
pre-commit run isort --all-files
|
86 |
+
```
|
activation/assert_utils.h
CHANGED
@@ -3,12 +3,15 @@
|
|
3 |
#include <ATen/cuda/CUDAContext.h>
|
4 |
#include <torch/all.h>
|
5 |
|
6 |
-
inline void AssertTensorNotNull(const torch::Tensor &tensor,
|
|
|
7 |
TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
|
8 |
}
|
9 |
|
10 |
-
inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
|
11 |
-
|
|
|
|
|
12 |
|
13 |
AssertTensorNotNull(tensor_a, name_a);
|
14 |
AssertTensorNotNull(tensor_b, name_b);
|
@@ -17,6 +20,7 @@ inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a, const torch::T
|
|
17 |
auto tensor_shape_b = tensor_b.sizes();
|
18 |
|
19 |
TORCH_INTERNAL_ASSERT(tensor_shape_a.equals(tensor_shape_b),
|
20 |
-
|
21 |
-
|
|
|
22 |
}
|
|
|
3 |
#include <ATen/cuda/CUDAContext.h>
|
4 |
#include <torch/all.h>
|
5 |
|
6 |
+
inline void AssertTensorNotNull(const torch::Tensor &tensor,
|
7 |
+
const std::string &name) {
|
8 |
TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
|
9 |
}
|
10 |
|
11 |
+
inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
|
12 |
+
const torch::Tensor &tensor_b,
|
13 |
+
const std::string &name_a,
|
14 |
+
const std::string &name_b) {
|
15 |
|
16 |
AssertTensorNotNull(tensor_a, name_a);
|
17 |
AssertTensorNotNull(tensor_b, name_b);
|
|
|
20 |
auto tensor_shape_b = tensor_b.sizes();
|
21 |
|
22 |
TORCH_INTERNAL_ASSERT(tensor_shape_a.equals(tensor_shape_b),
|
23 |
+
"{} tensor shape should be equal to {} tensor shape. "
|
24 |
+
"(actual: {}, expected: {})",
|
25 |
+
name_a, name_b, tensor_shape_a, tensor_shape_b);
|
26 |
}
|
activation/atomic_utils.h
CHANGED
@@ -1,35 +1,38 @@
|
|
1 |
#pragma once
|
2 |
|
3 |
-
#include <cuda.h>
|
4 |
#include <c10/util/BFloat16.h>
|
5 |
#include <c10/util/Half.h>
|
|
|
6 |
|
7 |
namespace motif {
|
8 |
-
template<typename scalar_t, typename acc_t>
|
9 |
-
__device__ inline void atomic_add(scalar_t*
|
10 |
// TODO: change assert to a static_assert if possible
|
11 |
-
|
12 |
}
|
13 |
|
14 |
-
template<>
|
15 |
-
__device__ inline void atomic_add<float, float>(float*
|
16 |
-
|
17 |
}
|
18 |
|
19 |
-
template<>
|
20 |
-
__device__ inline void atomic_add<double, double>(double*
|
21 |
-
|
|
|
22 |
}
|
23 |
|
24 |
-
template<>
|
25 |
-
__device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16*
|
26 |
-
|
|
|
|
|
27 |
|
28 |
size_t offset = (size_t)address & 0x2;
|
29 |
-
volatile uint16_t*
|
30 |
-
reinterpret_cast<volatile
|
31 |
-
volatile uint32_t*
|
32 |
-
reinterpret_cast<volatile
|
33 |
bool is_32bit_aligned = offset == 0;
|
34 |
|
35 |
uint32_t current = address_as_uint[0];
|
@@ -39,21 +42,24 @@ __device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16* _address,
|
|
39 |
expected = current;
|
40 |
c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
|
41 |
c10::BFloat16 next_bf16 = current_bf16 + value;
|
42 |
-
uint32_t next = is_32bit_aligned
|
43 |
-
|
44 |
-
|
|
|
|
|
45 |
} while (current != expected);
|
46 |
}
|
47 |
|
48 |
-
template<>
|
49 |
-
__device__ inline void atomic_add<c10::Half, float>(c10::Half*
|
50 |
-
|
|
|
51 |
|
52 |
size_t offset = (size_t)address & 0x2;
|
53 |
-
volatile uint16_t*
|
54 |
-
reinterpret_cast<volatile
|
55 |
-
volatile uint32_t*
|
56 |
-
reinterpret_cast<volatile
|
57 |
bool is_32bit_aligned = offset == 0;
|
58 |
|
59 |
uint32_t current = address_as_uint[0];
|
@@ -63,11 +69,12 @@ __device__ inline void atomic_add<c10::Half, float>(c10::Half* _address, float v
|
|
63 |
expected = current;
|
64 |
c10::Half current_half(address_as_short[0], c10::Half::from_bits());
|
65 |
c10::Half next_half = current_half + value;
|
66 |
-
uint32_t next = is_32bit_aligned
|
67 |
-
|
68 |
-
|
|
|
|
|
69 |
} while (current != expected);
|
70 |
-
|
71 |
}
|
72 |
|
73 |
} // namespace motif
|
|
|
1 |
#pragma once
|
2 |
|
|
|
3 |
#include <c10/util/BFloat16.h>
|
4 |
#include <c10/util/Half.h>
|
5 |
+
#include <cuda.h>
|
6 |
|
7 |
namespace motif {
|
8 |
+
template <typename scalar_t, typename acc_t>
|
9 |
+
__device__ inline void atomic_add(scalar_t *address, acc_t value) {
|
10 |
// TODO: change assert to a static_assert if possible
|
11 |
+
assert(false && "Unsupported type for atomic_add");
|
12 |
}
|
13 |
|
14 |
+
template <>
|
15 |
+
__device__ inline void atomic_add<float, float>(float *address, float value) {
|
16 |
+
atomicAdd(address, value);
|
17 |
}
|
18 |
|
19 |
+
template <>
|
20 |
+
__device__ inline void atomic_add<double, double>(double *address,
|
21 |
+
double value) {
|
22 |
+
atomicAdd(address, value);
|
23 |
}
|
24 |
|
25 |
+
template <>
|
26 |
+
__device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16 *_address,
|
27 |
+
float value) {
|
28 |
+
volatile c10::BFloat16 *address =
|
29 |
+
const_cast<volatile c10::BFloat16 *>(_address);
|
30 |
|
31 |
size_t offset = (size_t)address & 0x2;
|
32 |
+
volatile uint16_t *address_as_short = reinterpret_cast<volatile uint16_t *>(
|
33 |
+
reinterpret_cast<volatile char *>(address));
|
34 |
+
volatile uint32_t *address_as_uint = reinterpret_cast<volatile uint *>(
|
35 |
+
reinterpret_cast<volatile char *>(address) - offset);
|
36 |
bool is_32bit_aligned = offset == 0;
|
37 |
|
38 |
uint32_t current = address_as_uint[0];
|
|
|
42 |
expected = current;
|
43 |
c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
|
44 |
c10::BFloat16 next_bf16 = current_bf16 + value;
|
45 |
+
uint32_t next = is_32bit_aligned
|
46 |
+
? (current & 0xffff0000) | next_bf16.x
|
47 |
+
: (current & 0x0000ffff) | (next_bf16.x << 16);
|
48 |
+
current =
|
49 |
+
atomicCAS(const_cast<uint32_t *>(address_as_uint), expected, next);
|
50 |
} while (current != expected);
|
51 |
}
|
52 |
|
53 |
+
template <>
|
54 |
+
__device__ inline void atomic_add<c10::Half, float>(c10::Half *_address,
|
55 |
+
float value) {
|
56 |
+
volatile c10::Half *address = const_cast<volatile c10::Half *>(_address);
|
57 |
|
58 |
size_t offset = (size_t)address & 0x2;
|
59 |
+
volatile uint16_t *address_as_short = reinterpret_cast<volatile uint16_t *>(
|
60 |
+
reinterpret_cast<volatile char *>(address));
|
61 |
+
volatile uint32_t *address_as_uint = reinterpret_cast<volatile uint *>(
|
62 |
+
reinterpret_cast<volatile char *>(address) - offset);
|
63 |
bool is_32bit_aligned = offset == 0;
|
64 |
|
65 |
uint32_t current = address_as_uint[0];
|
|
|
69 |
expected = current;
|
70 |
c10::Half current_half(address_as_short[0], c10::Half::from_bits());
|
71 |
c10::Half next_half = current_half + value;
|
72 |
+
uint32_t next = is_32bit_aligned
|
73 |
+
? (current & 0xffff0000) | next_half.x
|
74 |
+
: (current & 0x0000ffff) | (next_half.x << 16);
|
75 |
+
current =
|
76 |
+
atomicCAS(const_cast<uint32_t *>(address_as_uint), expected, next);
|
77 |
} while (current != expected);
|
|
|
78 |
}
|
79 |
|
80 |
} // namespace motif
|
activation/block_reduce.h
CHANGED
@@ -1,7 +1,8 @@
|
|
1 |
namespace motif {
|
2 |
|
3 |
template <typename acc_t, int BLOCK_SIZE>
|
4 |
-
__device__ acc_t _block_reduce_sum(acc_t*
|
|
|
5 |
// TODO: Optimize with warp-level primitives
|
6 |
__syncthreads();
|
7 |
|
@@ -17,4 +18,4 @@ __device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d)
|
|
17 |
return shared[0];
|
18 |
}
|
19 |
|
20 |
-
} // motif
|
|
|
1 |
namespace motif {
|
2 |
|
3 |
template <typename acc_t, int BLOCK_SIZE>
|
4 |
+
__device__ acc_t _block_reduce_sum(acc_t *shared, const float val,
|
5 |
+
const int d) {
|
6 |
// TODO: Optimize with warp-level primitives
|
7 |
__syncthreads();
|
8 |
|
|
|
18 |
return shared[0];
|
19 |
}
|
20 |
|
21 |
+
} // namespace motif
|
activation/dispatch_utils.h
CHANGED
@@ -6,10 +6,11 @@
|
|
6 |
|
7 |
#include <torch/all.h>
|
8 |
|
9 |
-
#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...)
|
10 |
-
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
|
11 |
-
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
|
12 |
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
13 |
|
14 |
-
#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)
|
15 |
-
AT_DISPATCH_SWITCH(TYPE, NAME,
|
|
|
|
6 |
|
7 |
#include <torch/all.h>
|
8 |
|
9 |
+
#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...) \
|
10 |
+
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
|
11 |
+
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
|
12 |
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
|
13 |
|
14 |
+
#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
|
15 |
+
AT_DISPATCH_SWITCH(TYPE, NAME, \
|
16 |
+
MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
|
activation/rms_norm.cu
CHANGED
@@ -1,26 +1,23 @@
|
|
1 |
-
#include <ATen/cuda/CUDAContext.h>
|
2 |
#include <ATen/Functions.h>
|
3 |
-
#include <
|
4 |
#include <c10/cuda/CUDAGuard.h>
|
|
|
5 |
|
6 |
#include <cmath>
|
7 |
|
8 |
-
#include "cuda_compat.h"
|
9 |
-
#include "dispatch_utils.h"
|
10 |
#include "assert_utils.h"
|
11 |
#include "atomic_utils.h"
|
12 |
#include "block_reduce.h"
|
|
|
|
|
13 |
|
14 |
namespace motif {
|
15 |
|
16 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
17 |
-
__global__ void rms_norm_kernel(
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
const float eps,
|
22 |
-
const int d
|
23 |
-
) {
|
24 |
|
25 |
const int64_t token_idx = blockIdx.x;
|
26 |
const int64_t vec_idx = threadIdx.x;
|
@@ -44,15 +41,13 @@ __global__ void rms_norm_kernel(
|
|
44 |
}
|
45 |
|
46 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
47 |
-
__global__ void
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
const int d
|
55 |
-
) {
|
56 |
const int64_t token_idx = blockIdx.x;
|
57 |
const int64_t vec_idx = threadIdx.x;
|
58 |
acc_t d_sum = 0.0f;
|
@@ -80,8 +75,7 @@ __global__ void rms_norm_backward_kernel(
|
|
80 |
acc_t dy = output_grad[token_idx * d + idx];
|
81 |
acc_t w = weight[idx];
|
82 |
|
83 |
-
input_grad[token_idx * d + idx] =
|
84 |
-
scale * dy * w - dxx * x;
|
85 |
|
86 |
if (temp_weight_grad) {
|
87 |
temp_weight_grad[token_idx * d + idx] = dy * x * scale;
|
@@ -89,14 +83,12 @@ __global__ void rms_norm_backward_kernel(
|
|
89 |
}
|
90 |
}
|
91 |
|
92 |
-
}
|
93 |
-
|
94 |
|
95 |
-
void rms_norm(torch::Tensor&
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
{
|
100 |
AssertTensorShapeEqual(input, out, "input", "out");
|
101 |
AssertTensorNotNull(weight, "weight");
|
102 |
// TODO shape check
|
@@ -110,25 +102,20 @@ void rms_norm(torch::Tensor& out, // [..., d]
|
|
110 |
|
111 |
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
112 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
113 |
-
MOTIF_DISPATCH_FLOATING_TYPES(
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
weight.data_ptr<scalar_t>(),
|
120 |
-
eps, d);
|
121 |
-
}
|
122 |
-
);
|
123 |
}
|
124 |
|
125 |
-
void rms_norm_backward(
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
double eps) {
|
132 |
AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
|
133 |
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
134 |
AssertTensorNotNull(weight, "weight");
|
@@ -143,24 +130,20 @@ void rms_norm_backward(
|
|
143 |
dim3 block(BLOCK_SIZE);
|
144 |
|
145 |
torch::Tensor temp_weight_grad =
|
146 |
-
|
147 |
-
input.options().dtype(torch::kFloat));
|
148 |
|
149 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
150 |
|
151 |
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
152 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
eps, d);
|
162 |
-
}
|
163 |
-
);
|
164 |
|
165 |
if (weight_grad.defined()) {
|
166 |
at::sum_out(weight_grad, temp_weight_grad, {0});
|
|
|
|
|
1 |
#include <ATen/Functions.h>
|
2 |
+
#include <ATen/cuda/CUDAContext.h>
|
3 |
#include <c10/cuda/CUDAGuard.h>
|
4 |
+
#include <torch/all.h>
|
5 |
|
6 |
#include <cmath>
|
7 |
|
|
|
|
|
8 |
#include "assert_utils.h"
|
9 |
#include "atomic_utils.h"
|
10 |
#include "block_reduce.h"
|
11 |
+
#include "cuda_compat.h"
|
12 |
+
#include "dispatch_utils.h"
|
13 |
|
14 |
namespace motif {
|
15 |
|
16 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
17 |
+
__global__ void rms_norm_kernel(scalar_t *__restrict__ out, // [..., d]
|
18 |
+
const scalar_t *__restrict__ input, // [..., d]
|
19 |
+
const scalar_t *__restrict__ weight, // [d]
|
20 |
+
const float eps, const int d) {
|
|
|
|
|
|
|
21 |
|
22 |
const int64_t token_idx = blockIdx.x;
|
23 |
const int64_t vec_idx = threadIdx.x;
|
|
|
41 |
}
|
42 |
|
43 |
template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
|
44 |
+
__global__ void
|
45 |
+
rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d]
|
46 |
+
acc_t *__restrict__ temp_weight_grad, // [..., d]
|
47 |
+
const scalar_t *__restrict__ output_grad, // [..., d]
|
48 |
+
const scalar_t *__restrict__ input, // [..., d]
|
49 |
+
const scalar_t *__restrict__ weight, // [d]
|
50 |
+
const float eps, const int d) {
|
|
|
|
|
51 |
const int64_t token_idx = blockIdx.x;
|
52 |
const int64_t vec_idx = threadIdx.x;
|
53 |
acc_t d_sum = 0.0f;
|
|
|
75 |
acc_t dy = output_grad[token_idx * d + idx];
|
76 |
acc_t w = weight[idx];
|
77 |
|
78 |
+
input_grad[token_idx * d + idx] = scale * dy * w - dxx * x;
|
|
|
79 |
|
80 |
if (temp_weight_grad) {
|
81 |
temp_weight_grad[token_idx * d + idx] = dy * x * scale;
|
|
|
83 |
}
|
84 |
}
|
85 |
|
86 |
+
} // namespace motif
|
|
|
87 |
|
88 |
+
void rms_norm(torch::Tensor &out, // [..., d]
|
89 |
+
const torch::Tensor &input, // [..., d]
|
90 |
+
const torch::Tensor &weight, // [d]
|
91 |
+
double eps) {
|
|
|
92 |
AssertTensorShapeEqual(input, out, "input", "out");
|
93 |
AssertTensorNotNull(weight, "weight");
|
94 |
// TODO shape check
|
|
|
102 |
|
103 |
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
104 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
105 |
+
MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
106 |
+
motif::rms_norm_kernel<scalar_t, float, BLOCK_SIZE>
|
107 |
+
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),
|
108 |
+
input.data_ptr<scalar_t>(),
|
109 |
+
weight.data_ptr<scalar_t>(), eps, d);
|
110 |
+
});
|
|
|
|
|
|
|
|
|
111 |
}
|
112 |
|
113 |
+
void rms_norm_backward(torch::Tensor &input_grad, // [..., d]
|
114 |
+
torch::Tensor &weight_grad, // [..., d]
|
115 |
+
const torch::Tensor &output_grad, // [d]
|
116 |
+
const torch::Tensor &input, // [d]
|
117 |
+
const torch::Tensor &weight, // [d]
|
118 |
+
double eps) {
|
|
|
119 |
AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
|
120 |
AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
|
121 |
AssertTensorNotNull(weight, "weight");
|
|
|
130 |
dim3 block(BLOCK_SIZE);
|
131 |
|
132 |
torch::Tensor temp_weight_grad =
|
133 |
+
torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat));
|
|
|
134 |
|
135 |
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
136 |
|
137 |
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
138 |
MOTIF_DISPATCH_FLOATING_TYPES(
|
139 |
+
input.scalar_type(), "rms_norm_backward_kernel", [&] {
|
140 |
+
motif::rms_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
|
141 |
+
<<<grid, block, 0, stream>>>(input_grad.data_ptr<scalar_t>(),
|
142 |
+
temp_weight_grad.data_ptr<float>(),
|
143 |
+
output_grad.data_ptr<scalar_t>(),
|
144 |
+
input.data_ptr<scalar_t>(),
|
145 |
+
weight.data_ptr<scalar_t>(), eps, d);
|
146 |
+
});
|
|
|
|
|
|
|
147 |
|
148 |
if (weight_grad.defined()) {
|
149 |
at::sum_out(weight_grad, temp_weight_grad, {0});
|
tests/conftest.py
CHANGED
@@ -33,8 +33,7 @@ def plot(perf_results: list[PerfResult]):
|
|
33 |
textfont=dict(size=14),
|
34 |
textposition="outside",
|
35 |
# width=[bar_width] * len(x_labels),
|
36 |
-
)
|
37 |
-
)
|
38 |
|
39 |
fig.add_trace(
|
40 |
go.Bar(
|
@@ -46,12 +45,12 @@ def plot(perf_results: list[PerfResult]):
|
|
46 |
textfont=dict(size=14),
|
47 |
textposition="outside",
|
48 |
# width=[bar_width] * len(x_labels),
|
49 |
-
)
|
50 |
-
)
|
51 |
|
52 |
fig.update_layout(
|
53 |
title=dict(
|
54 |
-
text=
|
|
|
55 |
font=dict(size=24),
|
56 |
),
|
57 |
legend=dict(
|
@@ -96,12 +95,14 @@ def plot(perf_results: list[PerfResult]):
|
|
96 |
|
97 |
|
98 |
def pytest_addoption(parser):
|
99 |
-
parser.addoption(
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
|
|
|
|
105 |
|
106 |
|
107 |
@pytest.fixture
|
@@ -117,10 +118,10 @@ def pytest_configure(config):
|
|
117 |
if DO_PLOT and not run_perf:
|
118 |
raise ValueError(
|
119 |
"Cannot plot performance results without running performance tests. "
|
120 |
-
"Please use --run-perf option."
|
121 |
-
)
|
122 |
|
123 |
-
config.addinivalue_line("markers",
|
|
|
124 |
|
125 |
|
126 |
def pytest_collection_modifyitems(config, items):
|
@@ -128,8 +129,7 @@ def pytest_collection_modifyitems(config, items):
|
|
128 |
|
129 |
skip_perf = pytest.mark.skip(reason="need --run-perf option to run")
|
130 |
skip_normal = pytest.mark.skip(
|
131 |
-
reason="normal tests skipped when --run-perf is used"
|
132 |
-
)
|
133 |
for item in items:
|
134 |
if "perf" in item.keywords and not run_perf:
|
135 |
item.add_marker(skip_perf)
|
|
|
33 |
textfont=dict(size=14),
|
34 |
textposition="outside",
|
35 |
# width=[bar_width] * len(x_labels),
|
36 |
+
))
|
|
|
37 |
|
38 |
fig.add_trace(
|
39 |
go.Bar(
|
|
|
45 |
textfont=dict(size=14),
|
46 |
textposition="outside",
|
47 |
# width=[bar_width] * len(x_labels),
|
48 |
+
))
|
|
|
49 |
|
50 |
fig.update_layout(
|
51 |
title=dict(
|
52 |
+
text=
|
53 |
+
"<b>Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)</b>",
|
54 |
font=dict(size=24),
|
55 |
),
|
56 |
legend=dict(
|
|
|
95 |
|
96 |
|
97 |
def pytest_addoption(parser):
|
98 |
+
parser.addoption("--run-perf",
|
99 |
+
action="store_true",
|
100 |
+
default=False,
|
101 |
+
help="Run perf tests")
|
102 |
+
parser.addoption("--do-plot",
|
103 |
+
action="store_true",
|
104 |
+
default=False,
|
105 |
+
help="Plot performance results")
|
106 |
|
107 |
|
108 |
@pytest.fixture
|
|
|
118 |
if DO_PLOT and not run_perf:
|
119 |
raise ValueError(
|
120 |
"Cannot plot performance results without running performance tests. "
|
121 |
+
"Please use --run-perf option.")
|
|
|
122 |
|
123 |
+
config.addinivalue_line("markers",
|
124 |
+
"perf: mark test as performance-related")
|
125 |
|
126 |
|
127 |
def pytest_collection_modifyitems(config, items):
|
|
|
129 |
|
130 |
skip_perf = pytest.mark.skip(reason="need --run-perf option to run")
|
131 |
skip_normal = pytest.mark.skip(
|
132 |
+
reason="normal tests skipped when --run-perf is used")
|
|
|
133 |
for item in items:
|
134 |
if "perf" in item.keywords and not run_perf:
|
135 |
item.add_marker(skip_perf)
|
tests/kernels/allclose_default.py
CHANGED
@@ -3,7 +3,11 @@ import torch
|
|
3 |
# Reference default values of atol and rtol are from
|
4 |
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
|
5 |
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
|
6 |
-
default_rtol = {
|
|
|
|
|
|
|
|
|
7 |
|
8 |
|
9 |
def get_default_atol(output) -> float:
|
|
|
3 |
# Reference default values of atol and rtol are from
|
4 |
# https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
|
5 |
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
|
6 |
+
default_rtol = {
|
7 |
+
torch.float16: 1e-3,
|
8 |
+
torch.bfloat16: 1.6e-2,
|
9 |
+
torch.float: 1.3e-6
|
10 |
+
}
|
11 |
|
12 |
|
13 |
def get_default_atol(output) -> float:
|
tests/kernels/test_poly_norm.py
CHANGED
@@ -13,23 +13,20 @@ DTYPES = [torch.float, torch.bfloat16, torch.half]
|
|
13 |
NUM_TOKENS = [7, 13] # Arbitrary values for testing
|
14 |
D = [513] # Arbitrary values for testing
|
15 |
SEEDS = [0]
|
16 |
-
CUDA_DEVICES = [
|
|
|
|
|
17 |
|
18 |
|
19 |
def norm(x, eps: float) -> torch.Tensor:
|
20 |
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
21 |
|
22 |
|
23 |
-
def poly_norm(
|
24 |
-
|
25 |
-
) -> torch.Tensor:
|
26 |
x = x.float()
|
27 |
-
return (
|
28 |
-
|
29 |
-
+ weight[1] * norm(x**2, eps)
|
30 |
-
+ weight[2] * norm(x, eps)
|
31 |
-
+ bias
|
32 |
-
).to(weight.dtype)
|
33 |
|
34 |
|
35 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
|
|
13 |
NUM_TOKENS = [7, 13] # Arbitrary values for testing
|
14 |
D = [513] # Arbitrary values for testing
|
15 |
SEEDS = [0]
|
16 |
+
CUDA_DEVICES = [
|
17 |
+
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
18 |
+
]
|
19 |
|
20 |
|
21 |
def norm(x, eps: float) -> torch.Tensor:
|
22 |
return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
|
23 |
|
24 |
|
25 |
+
def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
|
26 |
+
eps: float) -> torch.Tensor:
|
|
|
27 |
x = x.float()
|
28 |
+
return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
|
29 |
+
weight[2] * norm(x, eps) + bias).to(weight.dtype)
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
tests/kernels/test_poly_norm_perf.py
CHANGED
@@ -94,7 +94,8 @@ def test_poly_norm(
|
|
94 |
return start.elapsed_time(end) / NUM_REP
|
95 |
|
96 |
kernel_time_ms = time_cuda(lambda: layer(x))
|
97 |
-
torch_fn_time = time_cuda(
|
|
|
98 |
|
99 |
PERF_RESULTS.append(
|
100 |
PerfResult(
|
@@ -103,11 +104,12 @@ def test_poly_norm(
|
|
103 |
dtype=dtype,
|
104 |
kernel_time_ms=kernel_time_ms,
|
105 |
torch_time_ms=torch_fn_time,
|
106 |
-
)
|
107 |
-
)
|
108 |
|
109 |
-
kernel_time_ms = time_cuda(
|
110 |
-
|
|
|
|
|
111 |
|
112 |
PERF_RESULTS.append(
|
113 |
PerfResult(
|
@@ -116,5 +118,4 @@ def test_poly_norm(
|
|
116 |
dtype=dtype,
|
117 |
kernel_time_ms=kernel_time_ms,
|
118 |
torch_time_ms=torch_fn_time,
|
119 |
-
)
|
120 |
-
)
|
|
|
94 |
return start.elapsed_time(end) / NUM_REP
|
95 |
|
96 |
kernel_time_ms = time_cuda(lambda: layer(x))
|
97 |
+
torch_fn_time = time_cuda(
|
98 |
+
lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))
|
99 |
|
100 |
PERF_RESULTS.append(
|
101 |
PerfResult(
|
|
|
104 |
dtype=dtype,
|
105 |
kernel_time_ms=kernel_time_ms,
|
106 |
torch_time_ms=torch_fn_time,
|
107 |
+
))
|
|
|
108 |
|
109 |
+
kernel_time_ms = time_cuda(
|
110 |
+
lambda: mod_out.backward(out_grad, retain_graph=True))
|
111 |
+
torch_fn_time = time_cuda(
|
112 |
+
lambda: ref_out.backward(out_grad, retain_graph=True))
|
113 |
|
114 |
PERF_RESULTS.append(
|
115 |
PerfResult(
|
|
|
118 |
dtype=dtype,
|
119 |
kernel_time_ms=kernel_time_ms,
|
120 |
torch_time_ms=torch_fn_time,
|
121 |
+
))
|
|
tests/kernels/test_rms_norm.py
CHANGED
@@ -13,7 +13,9 @@ DTYPES = [torch.float, torch.bfloat16, torch.half]
|
|
13 |
NUM_TOKENS = [7, 13] # Arbitrary values for testing
|
14 |
D = [513] # Arbitrary values for testing
|
15 |
SEEDS = [0]
|
16 |
-
CUDA_DEVICES = [
|
|
|
|
|
17 |
|
18 |
|
19 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
|
|
13 |
NUM_TOKENS = [7, 13] # Arbitrary values for testing
|
14 |
D = [513] # Arbitrary values for testing
|
15 |
SEEDS = [0]
|
16 |
+
CUDA_DEVICES = [
|
17 |
+
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
18 |
+
]
|
19 |
|
20 |
|
21 |
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
tests/kernels/utils.py
CHANGED
@@ -46,15 +46,19 @@ def fp8_allclose(
|
|
46 |
"""
|
47 |
Reference implementation of torch.allclose
|
48 |
"""
|
49 |
-
torch._refs._check_close_args(name="torch.allclose",
|
|
|
|
|
|
|
|
|
50 |
|
51 |
return bool(
|
52 |
torch.all(
|
53 |
-
torch.isclose(
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
|
59 |
|
60 |
# A special version of op check that has a restricted default set of test_utils
|
@@ -73,10 +77,9 @@ def opcheck(
|
|
73 |
cond: bool = True,
|
74 |
) -> Dict[str, str]:
|
75 |
with unittest.mock.patch("torch.allclose", new=fp8_allclose):
|
76 |
-
return (
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
)
|
|
|
46 |
"""
|
47 |
Reference implementation of torch.allclose
|
48 |
"""
|
49 |
+
torch._refs._check_close_args(name="torch.allclose",
|
50 |
+
a=a,
|
51 |
+
b=b,
|
52 |
+
rtol=rtol,
|
53 |
+
atol=atol)
|
54 |
|
55 |
return bool(
|
56 |
torch.all(
|
57 |
+
torch.isclose(a.double(),
|
58 |
+
b.double(),
|
59 |
+
rtol=rtol,
|
60 |
+
atol=atol,
|
61 |
+
equal_nan=equal_nan)).item())
|
62 |
|
63 |
|
64 |
# A special version of op check that has a restricted default set of test_utils
|
|
|
77 |
cond: bool = True,
|
78 |
) -> Dict[str, str]:
|
79 |
with unittest.mock.patch("torch.allclose", new=fp8_allclose):
|
80 |
+
return (torch.library.opcheck(op,
|
81 |
+
args,
|
82 |
+
kwargs,
|
83 |
+
test_utils=test_utils,
|
84 |
+
raise_exception=raise_exception)
|
85 |
+
if cond else {})
|
|
torch-ext/activation/layers.py
CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction
|
|
7 |
|
8 |
|
9 |
class PolyNorm(nn.Module):
|
|
|
10 |
def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
|
11 |
super().__init__()
|
12 |
self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
|
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):
|
|
28 |
|
29 |
|
30 |
class RMSNorm(nn.Module):
|
|
|
31 |
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
|
32 |
super().__init__()
|
33 |
self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
|
|
|
7 |
|
8 |
|
9 |
class PolyNorm(nn.Module):
|
10 |
+
|
11 |
def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
|
12 |
super().__init__()
|
13 |
self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
|
|
|
29 |
|
30 |
|
31 |
class RMSNorm(nn.Module):
|
32 |
+
|
33 |
def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
|
34 |
super().__init__()
|
35 |
self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
|
torch-ext/activation/poly_norm.py
CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
|
|
26 |
input, weight = ctx.saved_tensors
|
27 |
eps = ctx.eps
|
28 |
|
29 |
-
input_grad = torch.empty_like(
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
|
39 |
-
)
|
40 |
|
41 |
return input_grad, weight_grad, bias_grad, None
|
|
|
26 |
input, weight = ctx.saved_tensors
|
27 |
eps = ctx.eps
|
28 |
|
29 |
+
input_grad = torch.empty_like(
|
30 |
+
input) if ctx.needs_input_grad[0] else None
|
31 |
+
weight_grad = torch.empty_like(
|
32 |
+
weight) if ctx.needs_input_grad[1] else None
|
33 |
+
bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
|
34 |
+
if ctx.needs_input_grad[2] else None)
|
35 |
+
|
36 |
+
ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
|
37 |
+
input, weight, eps)
|
|
|
|
|
38 |
|
39 |
return input_grad, weight_grad, bias_grad, None
|
torch-ext/activation/rms_norm.py
CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
|
|
26 |
input, weight = ctx.saved_tensors
|
27 |
eps = ctx.eps
|
28 |
|
29 |
-
input_grad = torch.empty_like(
|
30 |
-
|
|
|
|
|
31 |
|
32 |
-
ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
|
|
|
33 |
|
34 |
return input_grad, weight_grad, None
|
|
|
26 |
input, weight = ctx.saved_tensors
|
27 |
eps = ctx.eps
|
28 |
|
29 |
+
input_grad = torch.empty_like(
|
30 |
+
input) if ctx.needs_input_grad[0] else None
|
31 |
+
weight_grad = torch.empty_like(
|
32 |
+
weight) if ctx.needs_input_grad[1] else None
|
33 |
|
34 |
+
ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
|
35 |
+
weight, eps)
|
36 |
|
37 |
return input_grad, weight_grad, None
|
torch-ext/torch_binding.cpp
CHANGED
@@ -5,18 +5,21 @@
|
|
5 |
|
6 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
7 |
// Activation ops
|
8 |
-
ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias,
|
9 |
-
|
|
|
|
|
|
|
10 |
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
|
11 |
ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
|
12 |
|
13 |
// Activation ops
|
14 |
-
ops.def(
|
15 |
-
|
|
|
|
|
16 |
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
17 |
ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
|
18 |
-
|
19 |
-
|
20 |
}
|
21 |
|
22 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
|
5 |
|
6 |
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
7 |
// Activation ops
|
8 |
+
ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, "
|
9 |
+
"float eps) -> ()");
|
10 |
+
ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! "
|
11 |
+
"bias_grad, Tensor output_grad, Tensor input, Tensor weight, float "
|
12 |
+
"eps) -> ()");
|
13 |
ops.impl("poly_norm", torch::kCUDA, &poly_norm);
|
14 |
ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
|
15 |
|
16 |
// Activation ops
|
17 |
+
ops.def(
|
18 |
+
"rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
|
19 |
+
ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor "
|
20 |
+
"output_grad, Tensor input, Tensor weight, float eps) -> ()");
|
21 |
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
|
22 |
ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
|
|
|
|
|
23 |
}
|
24 |
|
25 |
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_binding.h
CHANGED
@@ -2,8 +2,18 @@
|
|
2 |
|
3 |
#include <torch/torch.h>
|
4 |
|
5 |
-
void poly_norm(torch::Tensor &out, const torch::Tensor &input,
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
-
void rms_norm(torch::Tensor &out, const torch::Tensor &input,
|
9 |
-
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
#include <torch/torch.h>
|
4 |
|
5 |
+
void poly_norm(torch::Tensor &out, const torch::Tensor &input,
|
6 |
+
const torch::Tensor &weights, const torch::Tensor &bias,
|
7 |
+
double eps);
|
8 |
+
void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
|
9 |
+
torch::Tensor &bias_grad,
|
10 |
+
const torch::Tensor &output_grad,
|
11 |
+
const torch::Tensor &input, const torch::Tensor &weight,
|
12 |
+
double eps);
|
13 |
|
14 |
+
void rms_norm(torch::Tensor &out, const torch::Tensor &input,
|
15 |
+
const torch::Tensor &weights, double eps);
|
16 |
+
void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
|
17 |
+
const torch::Tensor &output_grad,
|
18 |
+
const torch::Tensor &input, const torch::Tensor &weight,
|
19 |
+
double eps);
|