TaehyunKimMotif committed
Commit f517c97 · 1 Parent(s): a73a0c0

add readme with precommit hooks and applied pre commit to all files

README.md CHANGED
@@ -32,6 +32,7 @@ print(poly_norm(x))
 
 - Test cases are from the Motif LLM
 - You can reproduce the results with:
+
 ```bash
 cd tests
 pytest --run-perf --do-plot
@@ -39,3 +40,47 @@ pytest --run-perf --do-plot
 
 ![PolyNorm Performance](./tests/perf.png)
 
+## Pre-commit Hooks
+
+This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
+
+### Setup
+
+1. Install pre-commit:
+
+```bash
+pip install pre-commit
+```
+
+2. Install the git hooks:
+
+```bash
+pre-commit install
+```
+
+Once installed, the configured hooks will run automatically on each commit.
+
+### Included Hooks
+
+The following tools are run via pre-commit:
+
+- **[yapf](https://github.com/google/yapf)** – Python code formatter
+- **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
+- **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
+- **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
+- **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
+- **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
+
+### Usage
+
+- Run all checks on the entire codebase:
+
+```bash
+pre-commit run --all-files
+```
+
+- Run a specific hook (example: isort):
+
+```bash
+pre-commit run isort --all-files
+```
activation/assert_utils.h CHANGED
@@ -3,12 +3,15 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>
 
-inline void AssertTensorNotNull(const torch::Tensor &tensor, const std::string &name) {
+inline void AssertTensorNotNull(const torch::Tensor &tensor,
+                                const std::string &name) {
   TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
 }
 
-inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a, const torch::Tensor &tensor_b,
-                                   const std::string &name_a, const std::string &name_b) {
+inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
+                                   const torch::Tensor &tensor_b,
+                                   const std::string &name_a,
+                                   const std::string &name_b) {
 
   AssertTensorNotNull(tensor_a, name_a);
   AssertTensorNotNull(tensor_b, name_b);
@@ -17,6 +20,7 @@ inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a, const torch::T
   auto tensor_shape_b = tensor_b.sizes();
 
   TORCH_INTERNAL_ASSERT(tensor_shape_a.equals(tensor_shape_b),
-                        "{} tensor shape should be equal to {} tensor shape. (actual: {}, expected: {})",
-                        name_a, name_b, tensor_shape_a, tensor_shape_b);
+                        "{} tensor shape should be equal to {} tensor shape. "
+                        "(actual: {}, expected: {})",
+                        name_a, name_b, tensor_shape_a, tensor_shape_b);
 }
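Both helpers above are header-only and are meant to be called at the top of an op's host entry point, as `rms_norm.cu` later in this commit does. A minimal sketch of that pattern (the op and tensor names here are illustrative, not part of the repo):

```cuda
#include <torch/all.h>

#include "assert_utils.h"

// Hypothetical host-side entry point: validate arguments before launching.
void my_op(torch::Tensor &out, const torch::Tensor &input,
           const torch::Tensor &weight) {
  // Fails with a readable message if the tensor is undefined.
  AssertTensorNotNull(weight, "weight");
  // Fails if the two shapes differ, reporting actual vs. expected sizes.
  AssertTensorShapeEqual(input, out, "input", "out");
  // ... kernel launch would go here ...
}
```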
activation/atomic_utils.h CHANGED
@@ -1,35 +1,38 @@
 #pragma once
 
-#include <cuda.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/Half.h>
+#include <cuda.h>
 
 namespace motif {
-template<typename scalar_t, typename acc_t>
-__device__ inline void atomic_add(scalar_t* address, acc_t value) {
+template <typename scalar_t, typename acc_t>
+__device__ inline void atomic_add(scalar_t *address, acc_t value) {
   // TODO: change assert to a static_assert if possible
-  assert(false && "Unsupported type for atomic_add");
+  assert(false && "Unsupported type for atomic_add");
 }
 
-template<>
-__device__ inline void atomic_add<float, float>(float* address, float value) {
-  atomicAdd(address, value);
+template <>
+__device__ inline void atomic_add<float, float>(float *address, float value) {
+  atomicAdd(address, value);
 }
 
-template<>
-__device__ inline void atomic_add<double, double>(double* address, double value) {
-  atomicAdd(address, value);
+template <>
+__device__ inline void atomic_add<double, double>(double *address,
+                                                  double value) {
+  atomicAdd(address, value);
 }
 
-template<>
-__device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16* _address, float value) {
-  volatile c10::BFloat16* address = const_cast<volatile c10::BFloat16*>(_address);
+template <>
+__device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16 *_address,
+                                                         float value) {
+  volatile c10::BFloat16 *address =
+      const_cast<volatile c10::BFloat16 *>(_address);
 
   size_t offset = (size_t)address & 0x2;
-  volatile uint16_t* address_as_short =
-      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
-  volatile uint32_t* address_as_uint =
-      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+  volatile uint16_t *address_as_short = reinterpret_cast<volatile uint16_t *>(
+      reinterpret_cast<volatile char *>(address));
+  volatile uint32_t *address_as_uint = reinterpret_cast<volatile uint *>(
+      reinterpret_cast<volatile char *>(address) - offset);
   bool is_32bit_aligned = offset == 0;
 
   uint32_t current = address_as_uint[0];
@@ -39,21 +42,24 @@ __device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16* _address,
     expected = current;
     c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
     c10::BFloat16 next_bf16 = current_bf16 + value;
-    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_bf16.x
-                                     : (current & 0x0000ffff) | (next_bf16.x << 16);
-    current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
+    uint32_t next = is_32bit_aligned
+                        ? (current & 0xffff0000) | next_bf16.x
+                        : (current & 0x0000ffff) | (next_bf16.x << 16);
+    current =
+        atomicCAS(const_cast<uint32_t *>(address_as_uint), expected, next);
   } while (current != expected);
 }
 
-template<>
-__device__ inline void atomic_add<c10::Half, float>(c10::Half* _address, float value) {
-  volatile c10::Half* address = const_cast<volatile c10::Half*>(_address);
+template <>
+__device__ inline void atomic_add<c10::Half, float>(c10::Half *_address,
+                                                    float value) {
+  volatile c10::Half *address = const_cast<volatile c10::Half *>(_address);
 
   size_t offset = (size_t)address & 0x2;
-  volatile uint16_t* address_as_short =
-      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
-  volatile uint32_t* address_as_uint =
-      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+  volatile uint16_t *address_as_short = reinterpret_cast<volatile uint16_t *>(
+      reinterpret_cast<volatile char *>(address));
+  volatile uint32_t *address_as_uint = reinterpret_cast<volatile uint *>(
+      reinterpret_cast<volatile char *>(address) - offset);
   bool is_32bit_aligned = offset == 0;
 
   uint32_t current = address_as_uint[0];
@@ -63,11 +69,12 @@ __device__ inline void atomic_add<c10::Half, float>(c10::Half* _address, float v
     expected = current;
     c10::Half current_half(address_as_short[0], c10::Half::from_bits());
     c10::Half next_half = current_half + value;
-    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_half.x
-                                     : (current & 0x0000ffff) | (next_half.x << 16);
-    current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
+    uint32_t next = is_32bit_aligned
+                        ? (current & 0xffff0000) | next_half.x
+                        : (current & 0x0000ffff) | (next_half.x << 16);
+    current =
+        atomicCAS(const_cast<uint32_t *>(address_as_uint), expected, next);
   } while (current != expected);
-
 }
 
 } // namespace motif
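The 16-bit specializations above cannot use a native `atomicAdd`, so they load the aligned 32-bit word that contains the half/bfloat16 value, splice the updated 16 bits into it, and retry with `atomicCAS` until the word has not changed underneath them. A hedged sketch of how a kernel would call the dispatcher (the kernel, buffer names, and launch shape are made up for illustration):

```cuda
#include <c10/util/BFloat16.h>

#include "atomic_utils.h"

// Illustrative kernel: every thread folds a float contribution into one of
// num_buckets bf16 accumulators; motif::atomic_add resolves to the CAS-based
// specialization for c10::BFloat16 with a float addend.
__global__ void accumulate_bf16(c10::BFloat16 *out, const float *vals, int n,
                                int num_buckets) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    motif::atomic_add(&out[i % num_buckets], vals[i]);
  }
}
```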
activation/block_reduce.h CHANGED
@@ -1,7 +1,8 @@
 namespace motif {
 
 template <typename acc_t, int BLOCK_SIZE>
-__device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d) {
+__device__ acc_t _block_reduce_sum(acc_t *shared, const float val,
+                                   const int d) {
   // TODO: Optimize with warp-level primitives
   __syncthreads();
 
@@ -17,4 +18,4 @@ __device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d)
   return shared[0];
 }
 
-} // motif
+} // namespace motif
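`_block_reduce_sum` takes a shared-memory scratch array sized to the block, each thread's partial value, and the row length `d` (its exact use of `d` is not visible in this hunk), and returns the block-wide sum. A sketch of the usage pattern, assuming one block per row as in the norm kernels:

```cuda
#include "block_reduce.h"

// Illustrative kernel: one block per row; each thread accumulates a strided
// partial sum, then the block total comes from motif::_block_reduce_sum.
template <int BLOCK_SIZE>
__global__ void row_sum(float *row_sums, const float *input, const int d) {
  __shared__ float shared[BLOCK_SIZE];
  float partial = 0.0f;
  for (int idx = threadIdx.x; idx < d; idx += BLOCK_SIZE) {
    partial += input[blockIdx.x * d + idx];
  }
  float total = motif::_block_reduce_sum<float, BLOCK_SIZE>(shared, partial, d);
  if (threadIdx.x == 0) {
    row_sums[blockIdx.x] = total;
  }
}
```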
activation/dispatch_utils.h CHANGED
@@ -6,10 +6,11 @@
 
 #include <torch/all.h>
 
-#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...) \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
-  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
+#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...)        \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
 
-#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME, MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME,                       \
+                     MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
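These macros wrap `AT_DISPATCH_SWITCH`/`AT_DISPATCH_CASE`, so inside the lambda `scalar_t` is bound to the concrete C++ type for float, half, or bfloat16 inputs. A small self-contained sketch (the helper function itself is invented for illustration; the real call sites are the kernel launches in `rms_norm.cu` below):

```cuda
#include <cstdio>

#include <torch/all.h>

#include "dispatch_utils.h"

// Illustrative helper: run dtype-specific code for any supported floating type.
void print_first_element(const torch::Tensor &t) {
  MOTIF_DISPATCH_FLOATING_TYPES(t.scalar_type(), "print_first_element", [&] {
    // scalar_t is float, c10::Half, or c10::BFloat16 depending on t.dtype().
    scalar_t first = t.cpu().data_ptr<scalar_t>()[0];
    std::printf("first element: %f\n", static_cast<float>(first));
  });
}
```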
activation/rms_norm.cu CHANGED
@@ -1,26 +1,23 @@
-#include <ATen/cuda/CUDAContext.h>
 #include <ATen/Functions.h>
-#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>
 
 #include <cmath>
 
-#include "cuda_compat.h"
-#include "dispatch_utils.h"
 #include "assert_utils.h"
 #include "atomic_utils.h"
 #include "block_reduce.h"
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
 
 namespace motif {
 
 template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
-__global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., d]
-    const scalar_t* __restrict__ input,  // [..., d]
-    const scalar_t* __restrict__ weight, // [d]
-    const float eps,
-    const int d
-) {
+__global__ void rms_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
+                                const scalar_t *__restrict__ input,  // [..., d]
+                                const scalar_t *__restrict__ weight, // [d]
+                                const float eps, const int d) {
 
   const int64_t token_idx = blockIdx.x;
   const int64_t vec_idx = threadIdx.x;
@@ -44,15 +41,13 @@ __global__ void rms_norm_kernel(
 }
 
 template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
-__global__ void rms_norm_backward_kernel(
-    scalar_t* __restrict__ input_grad,        // [..., d]
-    acc_t* __restrict__ temp_weight_grad,     // [..., d]
-    const scalar_t* __restrict__ output_grad, // [..., d]
-    const scalar_t* __restrict__ input,       // [..., d]
-    const scalar_t* __restrict__ weight,      // [d]
-    const float eps,
-    const int d
-) {
+__global__ void
+rms_norm_backward_kernel(scalar_t *__restrict__ input_grad,        // [..., d]
+                         acc_t *__restrict__ temp_weight_grad,     // [..., d]
+                         const scalar_t *__restrict__ output_grad, // [..., d]
+                         const scalar_t *__restrict__ input,       // [..., d]
+                         const scalar_t *__restrict__ weight,      // [d]
+                         const float eps, const int d) {
   const int64_t token_idx = blockIdx.x;
   const int64_t vec_idx = threadIdx.x;
   acc_t d_sum = 0.0f;
@@ -80,8 +75,7 @@ __global__ void rms_norm_backward_kernel(
     acc_t dy = output_grad[token_idx * d + idx];
     acc_t w = weight[idx];
 
-    input_grad[token_idx * d + idx] =
-        scale * dy * w - dxx * x;
+    input_grad[token_idx * d + idx] = scale * dy * w - dxx * x;
 
     if (temp_weight_grad) {
       temp_weight_grad[token_idx * d + idx] = dy * x * scale;
@@ -89,14 +83,12 @@
   }
 }
 
-} // namespace motif
-
+} // namespace motif
 
-void rms_norm(torch::Tensor& out,          // [..., d]
-    const torch::Tensor& input,            // [..., d]
-    const torch::Tensor& weight,           // [d]
-    double eps)
-{
+void rms_norm(torch::Tensor &out,           // [..., d]
+              const torch::Tensor &input,   // [..., d]
+              const torch::Tensor &weight,  // [d]
+              double eps) {
   AssertTensorShapeEqual(input, out, "input", "out");
   AssertTensorNotNull(weight, "weight");
   // TODO shape check
@@ -110,25 +102,20 @@ void rms_norm(torch::Tensor& out, // [..., d]
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  MOTIF_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "rms_norm_kernel", [&] {
-        motif::rms_norm_kernel<scalar_t, float, BLOCK_SIZE>
-            <<<grid, block, 0, stream>>>(
-                out.data_ptr<scalar_t>(),
-                input.data_ptr<scalar_t>(),
-                weight.data_ptr<scalar_t>(),
-                eps, d);
-      }
-  );
+  MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
+    motif::rms_norm_kernel<scalar_t, float, BLOCK_SIZE>
+        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                     input.data_ptr<scalar_t>(),
+                                     weight.data_ptr<scalar_t>(), eps, d);
+  });
 }
 
-void rms_norm_backward(
-    torch::Tensor& input_grad,        // [..., d]
-    torch::Tensor& weight_grad,       // [..., d]
-    const torch::Tensor& output_grad, // [d]
-    const torch::Tensor& input,       // [d]
-    const torch::Tensor& weight,      // [d]
-    double eps) {
+void rms_norm_backward(torch::Tensor &input_grad,        // [..., d]
+                       torch::Tensor &weight_grad,       // [..., d]
+                       const torch::Tensor &output_grad, // [d]
+                       const torch::Tensor &input,       // [d]
+                       const torch::Tensor &weight,      // [d]
+                       double eps) {
   AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
   AssertTensorNotNull(weight, "weight");
@@ -143,24 +130,20 @@ void rms_norm_backward(
   dim3 block(BLOCK_SIZE);
 
   torch::Tensor temp_weight_grad =
-      torch::empty({num_tokens, d},
-                   input.options().dtype(torch::kFloat));
+      torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat));
 
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   MOTIF_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "rms_norm_backward_kernel", [&] {
-        motif::rms_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
-            <<<grid, block, 0, stream>>>(
-                input_grad.data_ptr<scalar_t>(),
-                temp_weight_grad.data_ptr<float>(),
-                output_grad.data_ptr<scalar_t>(),
-                input.data_ptr<scalar_t>(),
-                weight.data_ptr<scalar_t>(),
-                eps, d);
-      }
-  );
+      input.scalar_type(), "rms_norm_backward_kernel", [&] {
+        motif::rms_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
+            <<<grid, block, 0, stream>>>(input_grad.data_ptr<scalar_t>(),
+                                         temp_weight_grad.data_ptr<float>(),
+                                         output_grad.data_ptr<scalar_t>(),
+                                         input.data_ptr<scalar_t>(),
+                                         weight.data_ptr<scalar_t>(), eps, d);
+      });
 
   if (weight_grad.defined()) {
     at::sum_out(weight_grad, temp_weight_grad, {0});
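For readers of the kernels above, the math being implemented is the usual RMSNorm. The forward definition is standard; the backward expressions below are a reader's derivation that matches the `scale * dy * w - dxx * x` and `dy * x * scale` terms in the kernel, not formulas documented in the repo:

```latex
% Forward, per token x in R^d with weight w in R^d:
\[
  r(x) = \sqrt{\tfrac{1}{d}\textstyle\sum_{j=1}^{d} x_j^{2} + \varepsilon},
  \qquad
  y_i = \frac{w_i\, x_i}{r(x)}
\]
% Backward, with upstream gradient g = dL/dy:
\[
  \frac{\partial L}{\partial x_i}
    = \frac{g_i\, w_i}{r(x)}
      - \frac{x_i}{d\, r(x)^{3}} \sum_{j=1}^{d} g_j\, w_j\, x_j,
  \qquad
  \frac{\partial L}{\partial w_i}
    = \sum_{\text{tokens}} \frac{g_i\, x_i}{r(x)}
\]
```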
tests/conftest.py CHANGED
@@ -33,8 +33,7 @@ def plot(perf_results: list[PerfResult]):
             textfont=dict(size=14),
             textposition="outside",
             # width=[bar_width] * len(x_labels),
-        )
-    )
+        ))
 
     fig.add_trace(
         go.Bar(
@@ -46,12 +45,12 @@ def plot(perf_results: list[PerfResult]):
             textfont=dict(size=14),
             textposition="outside",
             # width=[bar_width] * len(x_labels),
-        )
-    )
+        ))
 
     fig.update_layout(
         title=dict(
-            text="<b>Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)</b>",
+            text=
+            "<b>Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)</b>",
             font=dict(size=24),
         ),
         legend=dict(
@@ -96,12 +95,14 @@ def plot(perf_results: list[PerfResult]):
 
 
 def pytest_addoption(parser):
-    parser.addoption(
-        "--run-perf", action="store_true", default=False, help="Run perf tests"
-    )
-    parser.addoption(
-        "--do-plot", action="store_true", default=False, help="Plot performance results"
-    )
+    parser.addoption("--run-perf",
+                     action="store_true",
+                     default=False,
+                     help="Run perf tests")
+    parser.addoption("--do-plot",
+                     action="store_true",
+                     default=False,
+                     help="Plot performance results")
 
 
 @pytest.fixture
@@ -117,10 +118,10 @@ def pytest_configure(config):
     if DO_PLOT and not run_perf:
         raise ValueError(
             "Cannot plot performance results without running performance tests. "
-            "Please use --run-perf option."
-        )
+            "Please use --run-perf option.")
 
-    config.addinivalue_line("markers", "perf: mark test as performance-related")
+    config.addinivalue_line("markers",
+                            "perf: mark test as performance-related")
 
 
 def pytest_collection_modifyitems(config, items):
@@ -128,8 +129,7 @@ def pytest_collection_modifyitems(config, items):
 
     skip_perf = pytest.mark.skip(reason="need --run-perf option to run")
     skip_normal = pytest.mark.skip(
-        reason="normal tests skipped when --run-perf is used"
-    )
+        reason="normal tests skipped when --run-perf is used")
     for item in items:
         if "perf" in item.keywords and not run_perf:
             item.add_marker(skip_perf)
tests/kernels/allclose_default.py CHANGED
@@ -3,7 +3,11 @@ import torch
 # Reference default values of atol and rtol are from
 # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67
 default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
-default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}
+default_rtol = {
+    torch.float16: 1e-3,
+    torch.bfloat16: 1.6e-2,
+    torch.float: 1.3e-6
+}
 
 
 def get_default_atol(output) -> float:
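These per-dtype tolerances are consumed by `torch.isclose`/`torch.allclose`, whose elementwise test for finite values is:

```latex
\[
  \lvert a - b \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert b \rvert
\]
```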
tests/kernels/test_poly_norm.py CHANGED
@@ -13,23 +13,20 @@ DTYPES = [torch.float, torch.bfloat16, torch.half]
 NUM_TOKENS = [7, 13]  # Arbitrary values for testing
 D = [513]  # Arbitrary values for testing
 SEEDS = [0]
-CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
 
 
 def norm(x, eps: float) -> torch.Tensor:
     return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
 
 
-def poly_norm(
-    x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float
-) -> torch.Tensor:
+def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
+              eps: float) -> torch.Tensor:
     x = x.float()
-    return (
-        weight[0] * norm(x**3, eps)
-        + weight[1] * norm(x**2, eps)
-        + weight[2] * norm(x, eps)
-        + bias
-    ).to(weight.dtype)
+    return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
+            weight[2] * norm(x, eps) + bias).to(weight.dtype)
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
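Written out, the reference `poly_norm` above computes the following, with `norm` the RMS normalization defined just before it and all powers taken elementwise:

```latex
\[
  \operatorname{norm}(z) = \frac{z}{\sqrt{\operatorname{mean}(z^{2}) + \varepsilon}},
  \qquad
  \operatorname{PolyNorm}(x) = w_0 \operatorname{norm}(x^{3})
    + w_1 \operatorname{norm}(x^{2})
    + w_2 \operatorname{norm}(x) + b
\]
```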
tests/kernels/test_poly_norm_perf.py CHANGED
@@ -94,7 +94,8 @@ def test_poly_norm(
         return start.elapsed_time(end) / NUM_REP
 
     kernel_time_ms = time_cuda(lambda: layer(x))
-    torch_fn_time = time_cuda(lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))
+    torch_fn_time = time_cuda(
+        lambda: torch_fn(x_ref, weight_ref, bias_ref, eps))
 
     PERF_RESULTS.append(
         PerfResult(
@@ -103,11 +104,12 @@ def test_poly_norm(
             dtype=dtype,
             kernel_time_ms=kernel_time_ms,
             torch_time_ms=torch_fn_time,
-        )
-    )
+        ))
 
-    kernel_time_ms = time_cuda(lambda: mod_out.backward(out_grad, retain_graph=True))
-    torch_fn_time = time_cuda(lambda: ref_out.backward(out_grad, retain_graph=True))
+    kernel_time_ms = time_cuda(
+        lambda: mod_out.backward(out_grad, retain_graph=True))
+    torch_fn_time = time_cuda(
+        lambda: ref_out.backward(out_grad, retain_graph=True))
 
     PERF_RESULTS.append(
         PerfResult(
@@ -116,5 +118,4 @@ def test_poly_norm(
             dtype=dtype,
             kernel_time_ms=kernel_time_ms,
             torch_time_ms=torch_fn_time,
-        )
-    )
+        ))
tests/kernels/test_rms_norm.py CHANGED
@@ -13,7 +13,9 @@ DTYPES = [torch.float, torch.bfloat16, torch.half]
 NUM_TOKENS = [7, 13]  # Arbitrary values for testing
 D = [513]  # Arbitrary values for testing
 SEEDS = [0]
-CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
tests/kernels/utils.py CHANGED
@@ -46,15 +46,19 @@ def fp8_allclose(
     """
     Reference implementation of torch.allclose
     """
-    torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol)
+    torch._refs._check_close_args(name="torch.allclose",
+                                  a=a,
+                                  b=b,
+                                  rtol=rtol,
+                                  atol=atol)
 
     return bool(
         torch.all(
-            torch.isclose(
-                a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan
-            )
-        ).item()
-    )
+            torch.isclose(a.double(),
+                          b.double(),
+                          rtol=rtol,
+                          atol=atol,
+                          equal_nan=equal_nan)).item())
 
 
 # A special version of op check that has a restricted default set of test_utils
@@ -73,10 +77,9 @@ def opcheck(
     cond: bool = True,
 ) -> Dict[str, str]:
     with unittest.mock.patch("torch.allclose", new=fp8_allclose):
-        return (
-            torch.library.opcheck(
-                op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
-            )
-            if cond
-            else {}
-        )
+        return (torch.library.opcheck(op,
+                                      args,
+                                      kwargs,
+                                      test_utils=test_utils,
+                                      raise_exception=raise_exception)
+                if cond else {})
torch-ext/activation/layers.py CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction
 
 
 class PolyNorm(nn.Module):
+
     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):
 
 
 class RMSNorm(nn.Module):
+
    def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
torch-ext/activation/poly_norm.py CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
-        bias_grad = (
-            torch.empty(1, dtype=weight.dtype, device=weight.device)
-            if ctx.needs_input_grad[2]
-            else None
-        )
-
-        ops.poly_norm_backward(
-            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
-        )
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[2] else None)
+
+        ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
+                               input, weight, eps)
 
         return input_grad, weight_grad, bias_grad, None
torch-ext/activation/rms_norm.py CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
 
-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
+                              weight, eps)
 
         return input_grad, weight_grad, None
torch-ext/torch_binding.cpp CHANGED
@@ -5,18 +5,21 @@
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Activation ops
-  ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float eps) -> ()");
-  ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! bias_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
+  ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, "
+          "float eps) -> ()");
+  ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! "
+          "bias_grad, Tensor output_grad, Tensor input, Tensor weight, float "
+          "eps) -> ()");
   ops.impl("poly_norm", torch::kCUDA, &poly_norm);
   ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward);
 
   // Activation ops
-  ops.def("rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
-  ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()");
+  ops.def(
+      "rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()");
+  ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor "
+          "output_grad, Tensor input, Tensor weight, float eps) -> ()");
   ops.impl("rms_norm", torch::kCUDA, &rms_norm);
   ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward);
-
-
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h CHANGED
@@ -2,8 +2,18 @@
 
 #include <torch/torch.h>
 
-void poly_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, const torch::Tensor &bias, double eps);
-void poly_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, torch::Tensor& bias_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps);
+void poly_norm(torch::Tensor &out, const torch::Tensor &input,
+               const torch::Tensor &weights, const torch::Tensor &bias,
+               double eps);
+void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
+                        torch::Tensor &bias_grad,
+                        const torch::Tensor &output_grad,
+                        const torch::Tensor &input, const torch::Tensor &weight,
+                        double eps);
 
-void rms_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, double eps);
-void rms_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps);
+void rms_norm(torch::Tensor &out, const torch::Tensor &input,
+              const torch::Tensor &weights, double eps);
+void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
+                       const torch::Tensor &output_grad,
+                       const torch::Tensor &input, const torch::Tensor &weight,
+                       double eps);