This view is limited to 50 files because it contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. .pre-commit-config.yaml +36 -0
  2. README.md +45 -0
  3. activation/assert_utils.h +9 -5
  4. activation/atomic_utils.h +38 -31
  5. activation/block_reduce.h +3 -2
  6. activation/cuda_compat.h +9 -7
  7. activation/dispatch_utils.h +6 -5
  8. activation/poly_norm.cu +465 -156
  9. activation/rms_norm.cu +42 -59
  10. build/torch27-cxx11-cu118-x86_64-linux/activation/{_activation_cf68df1_dirty.abi3.so → _activation_f517c97_dirty.abi3.so} +2 -2
  11. build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py +3 -3
  12. build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py +2 -0
  13. build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py +9 -11
  14. build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py +6 -3
  15. build/torch27-cxx11-cu126-x86_64-linux/activation/{_activation_cf68df1_dirty.abi3.so → _activation_f517c97_dirty.abi3.so} +2 -2
  16. build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py +3 -3
  17. build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py +2 -0
  18. build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py +9 -11
  19. build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py +6 -3
  20. build/torch27-cxx11-cu128-x86_64-linux/activation/{_activation_cf68df1_dirty.abi3.so → _activation_f517c97_dirty.abi3.so} +2 -2
  21. build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py +3 -3
  22. build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py +2 -0
  23. build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py +9 -11
  24. build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py +6 -3
  25. build/torch27-cxx11-rocm63-x86_64-linux/activation/{_activation_cf68df1_dirty.abi3.so → _activation_f517c97_dirty.abi3.so} +2 -2
  26. build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +3 -3
  27. build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py +2 -0
  28. build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py +9 -11
  29. build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +6 -3
  30. build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +0 -3
  31. build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +3 -0
  32. build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py +3 -3
  33. build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py +2 -0
  34. build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py +9 -11
  35. build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py +6 -3
  36. build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +0 -3
  37. build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +3 -0
  38. build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py +3 -3
  39. build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py +2 -0
  40. build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py +9 -11
  41. build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py +6 -3
  42. build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +0 -3
  43. build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +3 -0
  44. build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py +3 -3
  45. build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py +2 -0
  46. build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py +9 -11
  47. build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py +6 -3
  48. build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +0 -3
  49. build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so +3 -0
  50. build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py +3 -3
.pre-commit-config.yaml ADDED
@@ -0,0 +1,36 @@
+default_install_hook_types:
+  - pre-commit
+  - commit-msg
+default_stages:
+  - pre-commit # Run locally
+  - manual # Run in CI
+exclude: '(build|result)/.*'
+repos:
+  - repo: https://github.com/google/yapf
+    rev: v0.43.0
+    hooks:
+      - id: yapf
+        args: [--in-place, --verbose]
+  - repo: https://github.com/crate-ci/typos
+    rev: v1.34.0
+    hooks:
+      - id: typos
+  - repo: https://github.com/PyCQA/isort
+    rev: 6.0.1
+    hooks:
+      - id: isort
+  - repo: https://github.com/pre-commit/mirrors-clang-format
+    rev: v20.1.3
+    hooks:
+      - id: clang-format
+        types_or: [c++, cuda]
+        args: [--style=file, --verbose]
+  - repo: https://github.com/jackdewinter/pymarkdown
+    rev: v0.9.29
+    hooks:
+      - id: pymarkdown
+        args: [fix]
+  - repo: https://github.com/rhysd/actionlint
+    rev: v1.7.7
+    hooks:
+      - id: actionlint
README.md CHANGED
@@ -32,6 +32,7 @@ print(poly_norm(x))

 - Test cases are from the Motif LLM
 - You can reproduce the results with:
+
 ```bash
 cd tests
 pytest --run-perf --do-plot
@@ -39,3 +40,47 @@ pytest --run-perf --do-plot

 ![PolyNorm Performance](./tests/perf.png)

+## Pre-commit Hooks
+
+This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
+
+### Setup
+
+1. Install pre-commit:
+
+```bash
+pip install pre-commit
+```
+
+2. Install the git hooks:
+
+```bash
+pre-commit install
+```
+
+Once installed, the configured hooks will run automatically on each commit.
+
+### Included Hooks
+
+The following tools are run via pre-commit:
+
+- **[yapf](https://github.com/google/yapf)** – Python code formatter
+- **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
+- **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
+- **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
+- **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
+- **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
+
+### Usage
+
+- Run all checks on the entire codebase:
+
+```bash
+pre-commit run --all-files
+```
+
+- Run a specific hook (example: isort):
+
+```bash
+pre-commit run isort --all-files
+```
activation/assert_utils.h CHANGED
@@ -3,12 +3,15 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/all.h>

-inline void AssertTensorNotNull(const torch::Tensor &tensor, const std::string &name) {
+inline void AssertTensorNotNull(const torch::Tensor &tensor,
+                                const std::string &name) {
   TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
 }

-inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a, const torch::Tensor &tensor_b,
-                                   const std::string &name_a, const std::string &name_b) {
+inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
+                                   const torch::Tensor &tensor_b,
+                                   const std::string &name_a,
+                                   const std::string &name_b) {

   AssertTensorNotNull(tensor_a, name_a);
   AssertTensorNotNull(tensor_b, name_b);
@@ -17,6 +20,7 @@ inline void AssertTensorShapeEqual
   auto tensor_shape_b = tensor_b.sizes();

   TORCH_INTERNAL_ASSERT(tensor_shape_a.equals(tensor_shape_b),
-                        "{} tensor shape should be equal to {} tensor shape. (actual: {}, expected: {})",
-                        name_a, name_b, tensor_shape_a, tensor_shape_b);
+                        "{} tensor shape should be equal to {} tensor shape. "
+                        "(actual: {}, expected: {})",
+                        name_a, name_b, tensor_shape_a, tensor_shape_b);
 }
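These helpers are called at the top of each op in the `.cu` diffs below (e.g. `AssertTensorShapeEqual(input, out, "input", "out")` in `poly_norm.cu`). A hypothetical call site, for orientation only (not part of the diff):

```cuda
// Hypothetical op entry point: validate arguments before doing any work.
void my_op(torch::Tensor &out, const torch::Tensor &input) {
  AssertTensorNotNull(input, "input");
  AssertTensorShapeEqual(input, out, "input", "out");
  // ... dispatch and launch the kernel here ...
}
```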
activation/atomic_utils.h CHANGED
@@ -1,35 +1,38 @@
 #pragma once

-#include <cuda.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/Half.h>
+#include <cuda.h>

 namespace motif {
-template<typename scalar_t, typename acc_t>
-__device__ inline void atomic_add(scalar_t* address, acc_t value) {
+template <typename scalar_t, typename acc_t>
+__device__ inline void atomic_add(scalar_t *address, acc_t value) {
   // TODO: change assert to a static_assert if possible
-  assert(false && "Unsupported type for atomic_add");
+  assert(false && "Unsupported type for atomic_add");
 }

-template<>
-__device__ inline void atomic_add<float, float>(float* address, float value) {
-  atomicAdd(address, value);
+template <>
+__device__ inline void atomic_add<float, float>(float *address, float value) {
+  atomicAdd(address, value);
 }

-template<>
-__device__ inline void atomic_add<double, double>(double* address, double value) {
-  atomicAdd(address, value);
+template <>
+__device__ inline void atomic_add<double, double>(double *address,
+                                                  double value) {
+  atomicAdd(address, value);
 }

-template<>
-__device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16* _address, float value) {
-  volatile c10::BFloat16* address = const_cast<volatile c10::BFloat16*>(_address);
+template <>
+__device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16 *_address,
+                                                        float value) {
+  volatile c10::BFloat16 *address =
+      const_cast<volatile c10::BFloat16 *>(_address);

   size_t offset = (size_t)address & 0x2;
-  volatile uint16_t* address_as_short =
-      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
-  volatile uint32_t* address_as_uint =
-      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+  volatile uint16_t *address_as_short = reinterpret_cast<volatile uint16_t *>(
+      reinterpret_cast<volatile char *>(address));
+  volatile uint32_t *address_as_uint = reinterpret_cast<volatile uint *>(
+      reinterpret_cast<volatile char *>(address) - offset);
   bool is_32bit_aligned = offset == 0;

   uint32_t current = address_as_uint[0];
@@ -39,21 +42,24 @@ __device__ inline void atomic_add<c10::BFloat16, float>
     expected = current;
     c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
     c10::BFloat16 next_bf16 = current_bf16 + value;
-    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_bf16.x
-                                     : (current & 0x0000ffff) | (next_bf16.x << 16);
-    current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
+    uint32_t next = is_32bit_aligned
+                        ? (current & 0xffff0000) | next_bf16.x
+                        : (current & 0x0000ffff) | (next_bf16.x << 16);
+    current =
+        atomicCAS(const_cast<uint32_t *>(address_as_uint), expected, next);
   } while (current != expected);
 }

-template<>
-__device__ inline void atomic_add<c10::Half, float>(c10::Half* _address, float value) {
-  volatile c10::Half* address = const_cast<volatile c10::Half*>(_address);
+template <>
+__device__ inline void atomic_add<c10::Half, float>(c10::Half *_address,
+                                                    float value) {
+  volatile c10::Half *address = const_cast<volatile c10::Half *>(_address);

   size_t offset = (size_t)address & 0x2;
-  volatile uint16_t* address_as_short =
-      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
-  volatile uint32_t* address_as_uint =
-      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+  volatile uint16_t *address_as_short = reinterpret_cast<volatile uint16_t *>(
+      reinterpret_cast<volatile char *>(address));
+  volatile uint32_t *address_as_uint = reinterpret_cast<volatile uint *>(
+      reinterpret_cast<volatile char *>(address) - offset);
   bool is_32bit_aligned = offset == 0;

   uint32_t current = address_as_uint[0];
@@ -63,11 +69,12 @@ __device__ inline void atomic_add<c10::Half, float>
     expected = current;
     c10::Half current_half(address_as_short[0], c10::Half::from_bits());
     c10::Half next_half = current_half + value;
-    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_half.x
-                                     : (current & 0x0000ffff) | (next_half.x << 16);
-    current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
+    uint32_t next = is_32bit_aligned
+                        ? (current & 0xffff0000) | next_half.x
+                        : (current & 0x0000ffff) | (next_half.x << 16);
+    current =
+        atomicCAS(const_cast<uint32_t *>(address_as_uint), expected, next);
   } while (current != expected);
-
 }

 } // namespace motif
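A note on the technique above: CUDA has no portable 16-bit `atomicAdd` for the c10 half types, so both specializations emulate it by locating the aligned 32-bit word that contains the 16-bit value and retrying with `atomicCAS` until no other thread has raced the update. A minimal self-contained sketch of the same CAS loop, written against plain `__half` rather than the c10 wrappers (the function name is illustrative, not part of the diff):

```cuda
#include <cuda_fp16.h>
#include <stdint.h>

// Emulated 16-bit atomic add via a 32-bit CAS loop (sketch).
__device__ void atomic_add_half_sketch(__half *addr, float value) {
  // Find the 32-bit word containing the 16-bit target (offset is 0 or 2).
  size_t offset = reinterpret_cast<size_t>(addr) & 0x2;
  uint32_t *word =
      reinterpret_cast<uint32_t *>(reinterpret_cast<char *>(addr) - offset);
  bool low_half = (offset == 0);

  uint32_t old = *word, assumed;
  do {
    assumed = old;
    // Extract the 16 bits we own, add in float, splice the result back in.
    uint16_t bits =
        low_half ? uint16_t(assumed & 0xffffu) : uint16_t(assumed >> 16);
    float f = __half2float(__ushort_as_half(bits)) + value;
    uint16_t next_bits = __half_as_ushort(__float2half(f));
    uint32_t next = low_half
                        ? (assumed & 0xffff0000u) | next_bits
                        : (assumed & 0x0000ffffu) | (uint32_t(next_bits) << 16);
    old = atomicCAS(word, assumed, next);
  } while (old != assumed); // another thread won the race: retry
}
```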
activation/block_reduce.h CHANGED
@@ -1,7 +1,8 @@
 namespace motif {

 template <typename acc_t, int BLOCK_SIZE>
-__device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d) {
+__device__ acc_t _block_reduce_sum(acc_t *shared, const float val,
+                                   const int d) {
   // TODO: Optimize with warp-level primitives
   __syncthreads();

@@ -17,4 +18,4 @@ __device__ acc_t _block_reduce_sum
   return shared[0];
 }

-} // motif
+} // namespace motif
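The reduction body itself is collapsed in this view. For orientation, a textbook shared-memory tree reduction matching this signature would look roughly like the sketch below; this is an assumption from the signature and the `shared[0]` return, not the file's actual implementation:

```cuda
// Sketch of a block-wide sum using a shared-memory tree reduction.
template <typename acc_t, int BLOCK_SIZE>
__device__ acc_t block_reduce_sum_sketch(acc_t *shared, const acc_t val) {
  shared[threadIdx.x] = val;
  __syncthreads();
  // Halve the active range each step until the total sits in shared[0].
  for (int stride = BLOCK_SIZE / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride)
      shared[threadIdx.x] += shared[threadIdx.x + stride];
    __syncthreads();
  }
  return shared[0];
}
```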
activation/cuda_compat.h CHANGED
@@ -1,18 +1,20 @@
 #pragma once

-#ifdef USE_ROCM
-#include <hip/hip_runtime.h>
+#ifndef USE_ROCM
+#include <cub/cub.cuh>
+#else
+#include <hip/hip_runtime.h>
+#include <hipcub/hipcub.hpp>
 #endif

 #ifndef USE_ROCM
-#define WARP_SIZE 32
+#define WARP_SIZE 32
 #else
-#define WARP_SIZE warpSize
+#define WARP_SIZE warpSize
 #endif

 #ifndef USE_ROCM
-#define VLLM_LDG(arg) __ldg(arg)
+#define VLLM_LDG(arg) __ldg(arg)
 #else
-#define VLLM_LDG(arg) *(arg)
+#define VLLM_LDG(arg) *(arg)
 #endif
-
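The newly added `cub`/`hipcub` includes support the `cub::BlockReduce` calls introduced in `poly_norm.cu` below, and the two macros smooth over the remaining CUDA/ROCm differences. A hypothetical kernel showing `VLLM_LDG` in use (illustration only, not part of the diff):

```cuda
// VLLM_LDG reads through the read-only data cache on CUDA (__ldg) and
// degrades to a plain load on ROCm; WARP_SIZE is 32 on CUDA and the
// device's warpSize on ROCm.
__global__ void scale_kernel(float *__restrict__ out,
                             const float *__restrict__ in, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n)
    out[i] = VLLM_LDG(&in[i]) * s;
}
```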
activation/dispatch_utils.h CHANGED
@@ -6,10 +6,11 @@

 #include <torch/all.h>

-#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...)        \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
-  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
+#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...)                               \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)                        \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)                         \
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

-#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME, MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                        \
+  AT_DISPATCH_SWITCH(TYPE, NAME,                                              \
+                     MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
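This macro pair expands to an `AT_DISPATCH_SWITCH` over float, half, and bfloat16, binding `scalar_t` inside the lambda. Its use mirrors the launch sites in the `.cu` diffs below; a condensed sketch with a hypothetical `my_kernel`:

```cuda
// Dispatch-and-launch pattern (my_kernel is hypothetical; the shape of the
// call matches the real launch sites in rms_norm.cu and poly_norm.cu).
MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "my_kernel", [&] {
  my_kernel<scalar_t><<<grid, block, 0, stream>>>(
      out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
});
```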
activation/poly_norm.cu CHANGED
@@ -1,246 +1,555 @@
-#include <ATen/cuda/CUDAContext.h>
 #include <ATen/Functions.h>
-#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>

 #include <cmath>

-#include "cuda_compat.h"
-#include "dispatch_utils.h"
 #include "assert_utils.h"
 #include "atomic_utils.h"
 #include "block_reduce.h"
+#include "cuda_compat.h"
+#include "dispatch_utils.h"

 namespace motif {

-template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
-__global__ void poly_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., d]
-    const scalar_t* __restrict__ input,  // [..., d]
-    const scalar_t* __restrict__ weight, // [3]
-    const scalar_t* __restrict__ bias,   // [1]
-    const float eps,
-    const int d
-) {
+template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
+  type data[N];
+};
+
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width > 0)>
+poly_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
+                 const scalar_t *__restrict__ input,  // [..., d]
+                 const scalar_t *__restrict__ weight, // [3]
+                 const scalar_t *__restrict__ bias,   // [1]
+                 const float eps, const int d) {
+  using vec_t = type_vec_t<scalar_t, width>;
+
+  const int vec_d = d / width;
+  const int64_t vec_offset = blockIdx.x * vec_d;
+  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
+
+  acc_t sum2 = 0.0f;
+  acc_t sum4 = 0.0f;
+  acc_t sum6 = 0.0f;
+
+  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + idx];
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x1 = static_cast<acc_t>(x_vec.data[i]);
+      acc_t x2 = x1 * x1;
+      acc_t x4 = x2 * x2;
+      acc_t x6 = x4 * x2;
+
+      sum2 += x2;
+      sum4 += x4;
+      sum6 += x6;
+    }
+  }
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
+  __syncthreads();
+  sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
+  __syncthreads();
+  sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+
+  __shared__ acc_t s_bias;
+
+  __shared__ acc_t s_w2_inv_std1;
+  __shared__ acc_t s_w1_inv_std2;
+  __shared__ acc_t s_w0_inv_std3;
+
+  if (threadIdx.x == 0) {
+    acc_t w0 = weight[0];
+    acc_t w1 = weight[1];
+    acc_t w2 = weight[2];
+    s_bias = bias[0];
+
+    s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2;
+    s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1;
+    s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0;
+  }
+  __syncthreads();
+
+  acc_t w2_inv_std1 = s_w2_inv_std1;
+  acc_t w1_inv_std2 = s_w1_inv_std2;
+  acc_t w0_inv_std3 = s_w0_inv_std3;
+  acc_t bias_reg = s_bias;
+
+  vec_t *__restrict__ output_vec = reinterpret_cast<vec_t *>(out);
+
+  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + idx];
+    vec_t y_vec;
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x1 = static_cast<acc_t>(x_vec.data[i]);
+      acc_t x2 = x1 * x1;
+      acc_t x3 = x2 * x1;
+
+      acc_t y =
+          x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
+
+      y_vec.data[i] = static_cast<scalar_t>(y);
+    }
+    output_vec[vec_offset + idx] = y_vec;
+  }
+}
+
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width == 0)>
+poly_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
+                 const scalar_t *__restrict__ input,  // [..., d]
+                 const scalar_t *__restrict__ weight, // [3]
+                 const scalar_t *__restrict__ bias,   // [1]
+                 const float eps, const int d) {
   const int64_t token_idx = blockIdx.x;

-  acc_t sum = 0.0f;
-  acc_t sum_square = 0.0f;
-  acc_t sum_cube = 0.0f;
+  acc_t sum2 = 0.0f;
+  acc_t sum4 = 0.0f;
+  acc_t sum6 = 0.0f;

   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    acc_t x = input[token_idx * d + idx];
-    sum += pow(x, 2.0f);
-    sum_square += pow(x, 4.0f);
-    sum_cube += pow(x, 6.0f);
+    acc_t x1 = input[token_idx * d + idx];
+    acc_t x2 = x1 * x1;
+    acc_t x4 = x2 * x2;
+    acc_t x6 = x4 * x2;
+
+    sum2 += x2;
+    sum4 += x4;
+    sum6 += x6;
   }

-  __shared__ acc_t shared[BLOCK_SIZE];
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;

-  acc_t mean = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum, d) / d;
-  acc_t mean_square = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
-  acc_t mean_cube = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_cube, d) / d;
+  sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
+  __syncthreads();
+  sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
+  __syncthreads();
+  sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);

-  acc_t w0 = weight[0];
-  acc_t w1 = weight[1];
-  acc_t w2 = weight[2];
-  acc_t b = bias[0];
+  __shared__ acc_t s_bias;
+
+  __shared__ acc_t s_w2_inv_std1;
+  __shared__ acc_t s_w1_inv_std2;
+  __shared__ acc_t s_w0_inv_std3;
+
+  if (threadIdx.x == 0) {
+    acc_t w0 = weight[0];
+    acc_t w1 = weight[1];
+    acc_t w2 = weight[2];
+    s_bias = bias[0];
+
+    s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2;
+    s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1;
+    s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0;
+  }
+  __syncthreads();

-  acc_t divisor = sqrt(mean + eps);
-  acc_t divisor_square = sqrt(mean_square + eps);
-  acc_t divisor_cube = sqrt(mean_cube + eps);
+  acc_t w2_inv_std1 = s_w2_inv_std1;
+  acc_t w1_inv_std2 = s_w1_inv_std2;
+  acc_t w0_inv_std3 = s_w0_inv_std3;
+  acc_t bias_reg = s_bias;

   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    acc_t x = input[token_idx * d + idx];
-    acc_t x_square = pow(x, 2.0f);
-    acc_t x_cube = pow(x, 3.0f);
-    out[token_idx * d + idx] = w2 * x / divisor +
-                               w1 * x_square / divisor_square +
-                               w0 * x_cube / divisor_cube + b;
+    acc_t x1 = input[token_idx * d + idx];
+    acc_t x2 = x1 * x1;
+    acc_t x3 = x2 * x1;
+    out[token_idx * d + idx] =
+        x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
   }
 }

-template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
-__global__ void poly_norm_backward_kernel(
-    scalar_t* __restrict__ input_grad,        // [..., d]
-    acc_t* __restrict__ temp_weight_grad,     // [..., 3]
-    const scalar_t* __restrict__ output_grad, // [..., d]
-    const scalar_t* __restrict__ input,       // [..., d]
-    const scalar_t* __restrict__ weight,      // [3]
-    const float eps,
-    const int d
-) {
-  const int64_t token_idx = blockIdx.x;
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width > 0)>
+poly_norm_backward_kernel(scalar_t *__restrict__ input_grad,        // [..., d]
+                          acc_t *__restrict__ temp_weight_grad,     // [..., 3]
+                          acc_t *__restrict__ temp_bias_grad,       // [..., 1]
+                          const scalar_t *__restrict__ output_grad, // [..., d]
+                          const scalar_t *__restrict__ input,       // [..., d]
+                          const scalar_t *__restrict__ weight,      // [3]
+                          const float eps, const int d) {
+  using vec_t = type_vec_t<scalar_t, width>;
+
+  const int vec_d = d / width;
+  const int64_t vec_offset = blockIdx.x * vec_d;
+  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
+  const vec_t *__restrict__ output_grad_vec =
+      reinterpret_cast<const vec_t *>(output_grad);
+
+  acc_t sum2 = 0.0f;
+  acc_t sum4 = 0.0f;
+  acc_t sum6 = 0.0f;
+
+  acc_t sum_dx1 = 0.0f;
+  acc_t sum_dx2 = 0.0f;
+  acc_t sum_dx3 = 0.0f;
+
+  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + idx];
+    vec_t dy_vec = output_grad_vec[vec_offset + idx];
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x1 = static_cast<acc_t>(x_vec.data[i]);
+      acc_t x2 = x1 * x1;
+      acc_t x3 = x2 * x1;
+      acc_t x4 = x2 * x2;
+      acc_t x6 = x3 * x3;
+
+      sum2 += x2;
+      sum4 += x4;
+      sum6 += x6;
+
+      acc_t dy = static_cast<acc_t>(dy_vec.data[i]);
+
+      sum_dx1 += dy * x1;
+      sum_dx2 += dy * x2;
+      sum_dx3 += dy * x3;
+    }
+  }
+
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  __syncthreads();
+  sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
+  __syncthreads();
+  sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
+  __syncthreads();
+  sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+
+  __syncthreads();
+  sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x);
+  __syncthreads();
+  sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x);
+  __syncthreads();
+  sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x);
+
+  __shared__ acc_t s_mean2;
+  __shared__ acc_t s_mean4;
+  __shared__ acc_t s_mean6;
+  __shared__ acc_t s_sdx1;
+  __shared__ acc_t s_sdx2;
+  __shared__ acc_t s_sdx3;
+
+  const acc_t inv_d = acc_t(1) / d;
+
+  if (threadIdx.x == 0) {
+    s_mean2 = sum2 * inv_d + eps;
+    s_mean4 = sum4 * inv_d + eps;
+    s_mean6 = sum6 * inv_d + eps;
+
+    s_sdx1 = sum_dx1 * inv_d;
+    s_sdx2 = sum_dx2 * inv_d;
+    s_sdx3 = sum_dx3 * inv_d;
+  }
+  __syncthreads();

   acc_t w0 = weight[0];
   acc_t w1 = weight[1];
   acc_t w2 = weight[2];

-  acc_t sum_2 = 0.0f;
-  acc_t sum_4 = 0.0f;
-  acc_t sum_6 = 0.0f;
+  acc_t mean2 = s_mean2;
+  acc_t mean4 = s_mean4;
+  acc_t mean6 = s_mean6;
+  acc_t sdx1 = s_sdx1;
+  acc_t sdx2 = s_sdx2;
+  acc_t sdx3 = s_sdx3;
+
+  acc_t inv_std1 = rsqrtf(mean2);
+  acc_t inv_std2 = rsqrtf(mean4);
+  acc_t inv_std3 = rsqrtf(mean6);
+
+  // inv_std / mean == powf(mean, -1.5)
+  acc_t c1 = w2 * inv_std1 / mean2;
+  acc_t c2 = acc_t(2) * w1 * inv_std2 / mean4;
+  acc_t c3 = acc_t(3) * w0 * inv_std3 / mean6;
+
+  acc_t sum_dy = 0;
+  acc_t sum_dw0 = 0;
+  acc_t sum_dw1 = 0;
+  acc_t sum_dw2 = 0;
+
+  vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
+
+  for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + idx];
+    vec_t dy_vec = output_grad_vec[vec_offset + idx];
+    vec_t dx_vec;
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x1 = static_cast<acc_t>(x_vec.data[i]);
+      acc_t x2 = x1 * x1;
+      acc_t x3 = x2 * x1;
+      acc_t dy = static_cast<acc_t>(dy_vec.data[i]);
+
+      if (input_grad) {
+        acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
+        acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
+        acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
+        dx_vec.data[i] = static_cast<scalar_t>(dx1 + dx2 + dx3);
+      }
+
+      sum_dy += dy;
+      sum_dw0 += dy * (x3 * inv_std3);
+      sum_dw1 += dy * (x2 * inv_std2);
+      sum_dw2 += dy * (x1 * inv_std1);
+    }
+
+    if (input_grad) {
+      input_grad_vec[vec_offset + idx] = dx_vec;
+    }
+  }
+
+  sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x);
+  __syncthreads();
+  sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x);
+  __syncthreads();
+  sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x);
+  __syncthreads();
+  sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x);
+
+  if (threadIdx.x == 0) {
+    temp_bias_grad[blockIdx.x] = sum_dy;
+    temp_weight_grad[blockIdx.x * 3 + 0] = sum_dw0;
+    temp_weight_grad[blockIdx.x * 3 + 1] = sum_dw1;
+    temp_weight_grad[blockIdx.x * 3 + 2] = sum_dw2;
+  }
+}
+
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width == 0)>
+poly_norm_backward_kernel(scalar_t *__restrict__ input_grad,        // [..., d]
+                          acc_t *__restrict__ temp_weight_grad,     // [..., 3]
+                          acc_t *__restrict__ temp_bias_grad,       // [..., 1]
+                          const scalar_t *__restrict__ output_grad, // [..., d]
+                          const scalar_t *__restrict__ input,       // [..., d]
+                          const scalar_t *__restrict__ weight,      // [3]
+                          const float eps, const int d) {
+  const int64_t token_idx = blockIdx.x;
+
+  acc_t sum2 = 0.0f;
+  acc_t sum4 = 0.0f;
+  acc_t sum6 = 0.0f;

-  acc_t sum_dx_1 = 0.0f;
-  acc_t sum_dx_2 = 0.0f;
-  acc_t sum_dx_3 = 0.0f;
+  acc_t sum_dx1 = 0.0f;
+  acc_t sum_dx2 = 0.0f;
+  acc_t sum_dx3 = 0.0f;

   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     acc_t dy = output_grad[token_idx * d + idx];

-    acc_t x_1 = input[token_idx * d + idx];
-    acc_t x_2 = x_1 * x_1;
-    acc_t x_3 = x_2 * x_1;
-    acc_t x_4 = x_2 * x_2;
-    acc_t x_6 = x_3 * x_3;
+    acc_t x1 = input[token_idx * d + idx];
+    acc_t x2 = x1 * x1;
+    acc_t x3 = x2 * x1;
+    acc_t x4 = x2 * x2;
+    acc_t x6 = x3 * x3;

-    sum_2 += x_2;
-    sum_4 += x_4;
-    sum_6 += x_6;
+    sum2 += x2;
+    sum4 += x4;
+    sum6 += x6;

-    sum_dx_1 += dy * x_1;
-    sum_dx_2 += dy * x_2;
-    sum_dx_3 += dy * x_3;
+    sum_dx1 += dy * x1;
+    sum_dx2 += dy * x2;
+    sum_dx3 += dy * x3;
   }

-  __shared__ acc_t shared[BLOCK_SIZE];
-
-  acc_t mean_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_2, d) / d + eps;
-  acc_t mean_4 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_4, d) / d + eps;
-  acc_t mean_6 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_6, d) / d + eps;
-
-  sum_dx_1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_1, d);
-  sum_dx_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_2, d);
-  sum_dx_3 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_3, d);
-
-  acc_t _mean_2 = powf(mean_2, -1.5);
-  acc_t _mean_4 = powf(mean_4, -1.5);
-  acc_t _mean_6 = powf(mean_6, -1.5);
-
-  acc_t sq_mean_2 = sqrtf(mean_2);
-  acc_t sq_mean_4 = sqrtf(mean_4);
-  acc_t sq_mean_6 = sqrtf(mean_6);
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+
+  __syncthreads();
+  sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
+  __syncthreads();
+  sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
+  __syncthreads();
+  sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+
+  __syncthreads();
+  sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x);
+  __syncthreads();
+  sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x);
+  __syncthreads();
+  sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x);
+
+  __shared__ acc_t s_mean2;
+  __shared__ acc_t s_mean4;
+  __shared__ acc_t s_mean6;
+  __shared__ acc_t s_sdx1;
+  __shared__ acc_t s_sdx2;
+  __shared__ acc_t s_sdx3;
+
+  const acc_t inv_d = acc_t(1) / d;
+
+  if (threadIdx.x == 0) {
+    s_mean2 = sum2 * inv_d + eps;
+    s_mean4 = sum4 * inv_d + eps;
+    s_mean6 = sum6 * inv_d + eps;
+
+    s_sdx1 = sum_dx1 * inv_d;
+    s_sdx2 = sum_dx2 * inv_d;
+    s_sdx3 = sum_dx3 * inv_d;
+  }
+  __syncthreads();
+
+  acc_t w0 = weight[0];
+  acc_t w1 = weight[1];
+  acc_t w2 = weight[2];
+
+  acc_t mean2 = s_mean2;
+  acc_t mean4 = s_mean4;
+  acc_t mean6 = s_mean6;
+  acc_t sdx1 = s_sdx1;
+  acc_t sdx2 = s_sdx2;
+  acc_t sdx3 = s_sdx3;
+
+  acc_t inv_std1 = rsqrtf(mean2);
+  acc_t inv_std2 = rsqrtf(mean4);
+  acc_t inv_std3 = rsqrtf(mean6);
+
+  // inv_std / mean == powf(mean, -1.5)
+  acc_t c1 = w2 * inv_std1 / mean2;
+  acc_t c2 = acc_t(2) * w1 * inv_std2 / mean4;
+  acc_t c3 = acc_t(3) * w0 * inv_std3 / mean6;

+  acc_t sum_dy = 0;
   acc_t sum_dw0 = 0;
   acc_t sum_dw1 = 0;
   acc_t sum_dw2 = 0;

   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     acc_t dy = output_grad[token_idx * d + idx];
-    acc_t x_1 = input[token_idx * d + idx];
-    acc_t x_2 = x_1 * x_1;
-    acc_t x_3 = x_2 * x_1;
-
-    acc_t dx_3 =
-        _mean_6 * 3 * x_2 * (dy * mean_6 - x_3 * sum_dx_3 / d) * w0;
-    acc_t dx_2 =
-        _mean_4 * 2 * x_1 * (dy * mean_4 - x_2 * sum_dx_2 / d) * w1;
-    acc_t dx_1 =
-        _mean_2 * (dy * mean_2 - x_1 * sum_dx_1 / d) * w2;
+    acc_t x1 = input[token_idx * d + idx];
+    acc_t x2 = x1 * x1;
+    acc_t x3 = x2 * x1;

     if (input_grad) {
-      input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3;
+      acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
+      acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
+      acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
+      input_grad[token_idx * d + idx] = dx1 + dx2 + dx3;
     }

-    sum_dw0 += dy * (x_3 / sq_mean_6);
-    sum_dw1 += dy * (x_2 / sq_mean_4);
-    sum_dw2 += dy * (x_1 / sq_mean_2);
+    sum_dy += dy;
+    sum_dw0 += dy * (x3 * inv_std3);
+    sum_dw1 += dy * (x2 * inv_std2);
+    sum_dw2 += dy * (x1 * inv_std1);
   }

-  if (temp_weight_grad) {
-    sum_dw0 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw0, d);
-    sum_dw1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw1, d);
-    sum_dw2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw2, d);
-
-    if (threadIdx.x == 0) {
-      temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
-      temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
-      temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
-    }
+  sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x);
+  __syncthreads();
+  sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x);
+  __syncthreads();
+  sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x);
+  __syncthreads();
+  sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x);
+
+  if (threadIdx.x == 0) {
+    temp_bias_grad[token_idx] = sum_dy;
+    temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
+    temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
+    temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
   }
 }

-} // namespace motif
-
-
-void poly_norm(torch::Tensor& out,          // [..., d]
-               const torch::Tensor& input,  // [..., d]
-               const torch::Tensor& weight, // [3]
-               const torch::Tensor& bias,   // [1]
-               double eps)
-{
+} // namespace motif
+
+#define LAUNCH_POLY_NORM(width)                                                \
+  MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
+    motif::poly_norm_kernel<scalar_t, float, width>                            \
+        <<<grid, block, 0, stream>>>(                                          \
+            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),              \
+            weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), eps, d);   \
+  });
+
+void poly_norm(torch::Tensor &out,          // [..., d]
+               const torch::Tensor &input,  // [..., d]
+               const torch::Tensor &weight, // [3]
+               const torch::Tensor &bias,   // [1]
+               double eps) {
   AssertTensorShapeEqual(input, out, "input", "out");
   AssertTensorNotNull(weight, "weight");
   AssertTensorNotNull(bias, "bias");
   // TODO shape check

-  constexpr int BLOCK_SIZE = 256;
-
   int d = input.size(-1);
-  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_tokens = input.numel() / d;
   dim3 grid(num_tokens);
-  dim3 block(BLOCK_SIZE);
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(d, max_block_size));

   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  MOTIF_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "poly_norm_kernel", [&] {
-        motif::poly_norm_kernel<scalar_t, float, BLOCK_SIZE>
-            <<<grid, block, 0, stream>>>(
-                out.data_ptr<scalar_t>(),
-                input.data_ptr<scalar_t>(),
-                weight.data_ptr<scalar_t>(),
-                bias.data_ptr<scalar_t>(), eps, d);
-      }
-  );
+  if (d % 8 == 0) {
+    LAUNCH_POLY_NORM(8);
+  } else {
+    LAUNCH_POLY_NORM(0);
+  }
 }

-void poly_norm_backward(
-    torch::Tensor& input_grad,        // [..., d]
-    torch::Tensor& weight_grad,       // [..., d]
-    torch::Tensor& bias_grad,         // [..., d]
-    const torch::Tensor& output_grad, // [3]
-    const torch::Tensor& input,       // [3]
-    const torch::Tensor& weight,      // [3]
-    double eps) {
+#define LAUNCH_POLY_NORM_BACKWARD(width)                                       \
+  MOTIF_DISPATCH_FLOATING_TYPES(                                               \
+      input.scalar_type(), "poly_norm_backward_kernel", [&] {                  \
+        motif::poly_norm_backward_kernel<scalar_t, float, width>               \
+            <<<grid, block, 0, stream>>>(input_grad.data_ptr<scalar_t>(),      \
+                                         temp_weight_grad.data_ptr<float>(),   \
+                                         temp_bias_grad.data_ptr<float>(),     \
+                                         output_grad.data_ptr<scalar_t>(),     \
+                                         input.data_ptr<scalar_t>(),           \
+                                         weight.data_ptr<scalar_t>(), eps, d); \
+      });
+
+void poly_norm_backward(torch::Tensor &input_grad,        // [..., d]
+                        torch::Tensor &weight_grad,       // [3]
+                        torch::Tensor &bias_grad,         // [1]
+                        const torch::Tensor &output_grad, // [..., d]
+                        const torch::Tensor &input,       // [..., d]
+                        const torch::Tensor &weight,      // [3]
+                        double eps) {
   AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
   AssertTensorNotNull(weight, "weight");
   // TODO shape check
   // weight_grad, bias_grad and input_grad can be nullable

-  constexpr int BLOCK_SIZE = 256;
-
   int d = input.size(-1);
-  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_tokens = input.numel() / d;
   dim3 grid(num_tokens);
-  dim3 block(BLOCK_SIZE);
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(d, max_block_size));

   torch::Tensor temp_weight_grad =
-      torch::empty({num_tokens, 3},
-                   input.options().dtype(torch::kFloat));
-
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+      torch::empty({num_tokens, 3}, input.options().dtype(torch::kFloat));
+  torch::Tensor temp_bias_grad =
+      torch::empty({num_tokens, 1}, output_grad.options().dtype(torch::kFloat));

   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  MOTIF_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "poly_norm_backward_kernel", [&] {
-        motif::poly_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
-            <<<grid, block, 0, stream>>>(
-                input_grad.data_ptr<scalar_t>(),
-                temp_weight_grad.data_ptr<float>(),
-                output_grad.data_ptr<scalar_t>(),
-                input.data_ptr<scalar_t>(),
-                weight.data_ptr<scalar_t>(),
-                eps, d);
-      }
-  );
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  if (d % 8 == 0) {
+    LAUNCH_POLY_NORM_BACKWARD(8);
+  } else {
+    LAUNCH_POLY_NORM_BACKWARD(0);
+  }

   if (bias_grad.defined()) {
-    at::sum_out(bias_grad, output_grad);
-    bias_grad.resize_({1});
+    torch::Tensor acc = torch::empty_like(bias_grad, temp_bias_grad.options());
+    at::sum_out(acc, temp_bias_grad, {0});
+    bias_grad.copy_(acc);
   }

   if (weight_grad.defined()) {
-    at::sum_out(weight_grad, temp_weight_grad, {0});
+    torch::Tensor acc =
+        torch::empty_like(weight_grad, temp_weight_grad.options());
+    at::sum_out(acc, temp_weight_grad, {0});
+    weight_grad.copy_(acc);
   }
 }
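Both the new vectorized (`width > 0`) and scalar (`width == 0`) paths of the forward kernel compute the same per-token map, which, reading off the code, is

```latex
y = w_2 \frac{x}{\sqrt{\overline{x^2} + \epsilon}}
  + w_1 \frac{x^2}{\sqrt{\overline{x^4} + \epsilon}}
  + w_0 \frac{x^3}{\sqrt{\overline{x^6} + \epsilon}}
  + b
```

where $\overline{x^k}$ denotes the mean of $x^k$ over the last dimension $d$. The rewrite precomputes the three fused factors $w_j \cdot \mathrm{rsqrt}(\overline{x^{2k}} + \epsilon)$ once per block in shared memory, so the second pass is three multiply-adds per element instead of three divisions and `pow` calls.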
activation/rms_norm.cu CHANGED
@@ -1,26 +1,23 @@
-#include <ATen/cuda/CUDAContext.h>
 #include <ATen/Functions.h>
-#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <torch/all.h>

 #include <cmath>

-#include "cuda_compat.h"
-#include "dispatch_utils.h"
 #include "assert_utils.h"
 #include "atomic_utils.h"
 #include "block_reduce.h"
+#include "cuda_compat.h"
+#include "dispatch_utils.h"

 namespace motif {

 template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
-__global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., d]
-    const scalar_t* __restrict__ input,  // [..., d]
-    const scalar_t* __restrict__ weight, // [d]
-    const float eps,
-    const int d
-) {
+__global__ void rms_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
+                                const scalar_t *__restrict__ input,  // [..., d]
+                                const scalar_t *__restrict__ weight, // [d]
+                                const float eps, const int d) {
   const int64_t token_idx = blockIdx.x;
   const int64_t vec_idx = threadIdx.x;
@@ -44,15 +41,13 @@ __global__ void rms_norm_kernel(
 }

 template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
-__global__ void rms_norm_backward_kernel(
-    scalar_t* __restrict__ input_grad,        // [..., d]
-    acc_t* __restrict__ temp_weight_grad,     // [..., d]
-    const scalar_t* __restrict__ output_grad, // [..., d]
-    const scalar_t* __restrict__ input,       // [..., d]
-    const scalar_t* __restrict__ weight,      // [d]
-    const float eps,
-    const int d
-) {
+__global__ void
+rms_norm_backward_kernel(scalar_t *__restrict__ input_grad,        // [..., d]
+                         acc_t *__restrict__ temp_weight_grad,     // [..., d]
+                         const scalar_t *__restrict__ output_grad, // [..., d]
+                         const scalar_t *__restrict__ input,       // [..., d]
+                         const scalar_t *__restrict__ weight,      // [d]
+                         const float eps, const int d) {
   const int64_t token_idx = blockIdx.x;
   const int64_t vec_idx = threadIdx.x;
   acc_t d_sum = 0.0f;
@@ -80,8 +75,7 @@ __global__ void rms_norm_backward_kernel(
     acc_t dy = output_grad[token_idx * d + idx];
     acc_t w = weight[idx];

-    input_grad[token_idx * d + idx] =
-        scale * dy * w - dxx * x;
+    input_grad[token_idx * d + idx] = scale * dy * w - dxx * x;

     if (temp_weight_grad) {
       temp_weight_grad[token_idx * d + idx] = dy * x * scale;
@@ -89,14 +83,12 @@
   }
 }

-} // namespace motif
-
+} // namespace motif

-void rms_norm(torch::Tensor& out,          // [..., d]
-              const torch::Tensor& input,  // [..., d]
-              const torch::Tensor& weight, // [d]
-              double eps)
-{
+void rms_norm(torch::Tensor &out,          // [..., d]
+              const torch::Tensor &input,  // [..., d]
+              const torch::Tensor &weight, // [d]
+              double eps) {
   AssertTensorShapeEqual(input, out, "input", "out");
   AssertTensorNotNull(weight, "weight");
   // TODO shape check
@@ -110,25 +102,20 @@ void rms_norm(torch::Tensor& out, // [..., d]

   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  MOTIF_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "rms_norm_kernel", [&] {
-        motif::rms_norm_kernel<scalar_t, float, BLOCK_SIZE>
-            <<<grid, block, 0, stream>>>(
-                out.data_ptr<scalar_t>(),
-                input.data_ptr<scalar_t>(),
-                weight.data_ptr<scalar_t>(),
-                eps, d);
-      }
-  );
+  MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
+    motif::rms_norm_kernel<scalar_t, float, BLOCK_SIZE>
+        <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),
+                                     input.data_ptr<scalar_t>(),
+                                     weight.data_ptr<scalar_t>(), eps, d);
+  });
 }

-void rms_norm_backward(
-    torch::Tensor& input_grad,        // [..., d]
-    torch::Tensor& weight_grad,       // [..., d]
-    const torch::Tensor& output_grad, // [d]
-    const torch::Tensor& input,       // [d]
-    const torch::Tensor& weight,      // [d]
-    double eps) {
+void rms_norm_backward(torch::Tensor &input_grad,        // [..., d]
+                       torch::Tensor &weight_grad,       // [..., d]
+                       const torch::Tensor &output_grad, // [d]
+                       const torch::Tensor &input,       // [d]
+                       const torch::Tensor &weight,      // [d]
+                       double eps) {
   AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
   AssertTensorNotNull(weight, "weight");
@@ -143,24 +130,20 @@ void rms_norm_backward(
   dim3 block(BLOCK_SIZE);

   torch::Tensor temp_weight_grad =
-      torch::empty({num_tokens, d},
-                   input.options().dtype(torch::kFloat));
+      torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat));

   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   MOTIF_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "rms_norm_backward_kernel", [&] {
-        motif::rms_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
-            <<<grid, block, 0, stream>>>(
-                input_grad.data_ptr<scalar_t>(),
-                temp_weight_grad.data_ptr<float>(),
-                output_grad.data_ptr<scalar_t>(),
-                input.data_ptr<scalar_t>(),
-                weight.data_ptr<scalar_t>(),
-                eps, d);
-      }
-  );
+      input.scalar_type(), "rms_norm_backward_kernel", [&] {
+        motif::rms_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
+            <<<grid, block, 0, stream>>>(input_grad.data_ptr<scalar_t>(),
+                                         temp_weight_grad.data_ptr<float>(),
+                                         output_grad.data_ptr<scalar_t>(),
+                                         input.data_ptr<scalar_t>(),
+                                         weight.data_ptr<scalar_t>(), eps, d);
+      });

   if (weight_grad.defined()) {
     at::sum_out(weight_grad, temp_weight_grad, {0});
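For reference, this file's changes are formatting plus the header reshuffle; the op itself is standard RMSNorm,

```latex
y_i = w_i \, \frac{x_i}{\sqrt{\overline{x^2} + \epsilon}},
```

and the visible backward line `temp_weight_grad[token_idx * d + idx] = dy * x * scale` is the per-token weight-gradient contribution $\sum_{\text{tokens}} dy_i \, x_i / \sqrt{\overline{x^2}+\epsilon}$, assuming `scale` (whose definition is elided in this view) is the usual $1/\sqrt{\overline{x^2}+\epsilon}$.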
build/torch27-cxx11-cu118-x86_64-linux/activation/{_activation_cf68df1_dirty.abi3.so → _activation_f517c97_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1655e52503ce7d0b7dabd55b97c1bd7d11071cbe0f80b9e810c443523638fd9b
-size 2994312
+oid sha256:bd84c828d4c15e96d65d6c8f0eb7a945ee8167d92e978b2ebce03eeaf41e7fce
+size 4405112
build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_cf68df1_dirty
-ops = torch.ops._activation_cf68df1_dirty
+from . import _activation_f517c97_dirty
+ops = torch.ops._activation_f517c97_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_cf68df1_dirty::{op_name}"
+    return f"_activation_f517c97_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction


 class PolyNorm(nn.Module):
+
     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):


 class RMSNorm(nn.Module):
+
     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps

-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
-        bias_grad = (
-            torch.empty(1, dtype=weight.dtype, device=weight.device)
-            if ctx.needs_input_grad[2]
-            else None
-        )
-
-        ops.poly_norm_backward(
-            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
-        )
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[2] else None)
+
+        ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
+                               input, weight, eps)

         return input_grad, weight_grad, bias_grad, None
build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps

-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None

-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
+                              weight, eps)

         return input_grad, weight_grad, None
build/torch27-cxx11-cu126-x86_64-linux/activation/{_activation_cf68df1_dirty.abi3.so → _activation_f517c97_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:174dbe4375aa22fb34d9d23630b3bec4eeb95635ef681b665db0985e78cf5af3
-size 3027504
+oid sha256:caffcadbb99fbaa27e8a81d5ef508f2e1a798e7626d618c3cf5b0d387d2c8686
+size 4618624
build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_cf68df1_dirty
-ops = torch.ops._activation_cf68df1_dirty
+from . import _activation_f517c97_dirty
+ops = torch.ops._activation_f517c97_dirty

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_cf68df1_dirty::{op_name}"
+    return f"_activation_f517c97_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction
 
 
 class PolyNorm(nn.Module):
+
     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):
 
 
 class RMSNorm(nn.Module):
+
     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
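The `layers.py` change is purely cosmetic: a blank line after each class statement, consistent with the formatter applied across this change. For context, instantiating these modules looks roughly like this (a sketch that assumes `forward` takes a single tensor, which this hunk does not show):

```python
import torch
from activation.layers import PolyNorm, RMSNorm

x = torch.randn(2, 16, 128, device="cuda")
y1 = PolyNorm(eps=1e-6)(x)          # 3 learnable coefficients, init ones(3) / 3
y2 = RMSNorm(dim=128, eps=1e-6)(x)  # per-channel weight, init ones(128)
```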
build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
-        bias_grad = (
-            torch.empty(1, dtype=weight.dtype, device=weight.device)
-            if ctx.needs_input_grad[2]
-            else None
-        )
-
-        ops.poly_norm_backward(
-            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
-        )
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[2] else None)
+
+        ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
+                               input, weight, eps)
 
         return input_grad, weight_grad, bias_grad, None
build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
 
-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
+                              weight, eps)
 
         return input_grad, weight_grad, None
build/torch27-cxx11-cu128-x86_64-linux/activation/{_activation_cf68df1_dirty.abi3.so → _activation_f517c97_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:91d71ca84a19b393c22b269226a7b4ddadbf1feec73a80bd45f655179c7a53f5
-size 3987512
+oid sha256:3b7c6ece8e8d316c4cc5fe46b1cec4422b2f61e9bb7240af71a2b4a35975d8e6
+size 6676528
build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_cf68df1_dirty
-ops = torch.ops._activation_cf68df1_dirty
+from . import _activation_f517c97_dirty
+ops = torch.ops._activation_f517c97_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_cf68df1_dirty::{op_name}"
+    return f"_activation_f517c97_dirty::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction
 
 
 class PolyNorm(nn.Module):
+
     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):
 
 
 class RMSNorm(nn.Module):
+
     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
-        bias_grad = (
-            torch.empty(1, dtype=weight.dtype, device=weight.device)
-            if ctx.needs_input_grad[2]
-            else None
-        )
-
-        ops.poly_norm_backward(
-            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
-        )
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[2] else None)
+
+        ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
+                               input, weight, eps)
 
         return input_grad, weight_grad, bias_grad, None
build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
 
-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
+                              weight, eps)
 
         return input_grad, weight_grad, None
build/torch27-cxx11-rocm63-x86_64-linux/activation/{_activation_cf68df1_dirty.abi3.so → _activation_f517c97_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ab1037bf6b41bf2be1d00a6a0ed01a97a5e4d64dd0abaf509492ad31eea0a576
-size 2642976
+oid sha256:4be173820e2a4bf4b6b8de6b63faf6544b599d9b0583f650a940adaef4a048b3
+size 2899184
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_cf68df1_dirty
-ops = torch.ops._activation_cf68df1_dirty
+from . import _activation_f517c97_dirty
+ops = torch.ops._activation_f517c97_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_cf68df1_dirty::{op_name}"
+    return f"_activation_f517c97_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction
 
 
 class PolyNorm(nn.Module):
+
     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):
 
 
 class RMSNorm(nn.Module):
+
     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
-        bias_grad = (
-            torch.empty(1, dtype=weight.dtype, device=weight.device)
-            if ctx.needs_input_grad[2]
-            else None
-        )
-
-        ops.poly_norm_backward(
-            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
-        )
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[2] else None)
+
+        ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
+                               input, weight, eps)
 
         return input_grad, weight_grad, bias_grad, None
build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
 
-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
+                              weight, eps)
 
         return input_grad, weight_grad, None
build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:012788f2064588edf60df24778dff33f8ca95e3b1aaf5243554735cd783dd7ed
-size 3032488
build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cb222449350310f90f7271f34fcf9052c9eec28021fee0348130a8f239a97bf4
+size 4571976
build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_cf68df1_dirty
-ops = torch.ops._activation_cf68df1_dirty
+from . import _activation_f517c97_dirty
+ops = torch.ops._activation_f517c97_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_cf68df1_dirty::{op_name}"
+    return f"_activation_f517c97_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction
 
 
 class PolyNorm(nn.Module):
+
     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):
 
 
 class RMSNorm(nn.Module):
+
     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
-        bias_grad = (
-            torch.empty(1, dtype=weight.dtype, device=weight.device)
-            if ctx.needs_input_grad[2]
-            else None
-        )
-
-        ops.poly_norm_backward(
-            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
-        )
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[2] else None)
+
+        ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
+                               input, weight, eps)
 
         return input_grad, weight_grad, bias_grad, None
build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
 
-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
+                              weight, eps)
 
         return input_grad, weight_grad, None
build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b1a65b79b750f550a09e6a1142b5151b03b2a60ec6115a264e6d8de3cac7ee5d
-size 4000920
build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:79be6527f579de1133e50a66310d7d0690649dcac63009a54b5e68809408f12a
+size 6634208
build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_cf68df1_dirty
-ops = torch.ops._activation_cf68df1_dirty
+from . import _activation_f517c97_dirty
+ops = torch.ops._activation_f517c97_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_cf68df1_dirty::{op_name}"
+    return f"_activation_f517c97_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction
 
 
 class PolyNorm(nn.Module):
+
     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):
 
 
 class RMSNorm(nn.Module):
+
     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
-        bias_grad = (
-            torch.empty(1, dtype=weight.dtype, device=weight.device)
-            if ctx.needs_input_grad[2]
-            else None
-        )
-
-        ops.poly_norm_backward(
-            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
-        )
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[2] else None)
+
+        ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
+                               input, weight, eps)
 
         return input_grad, weight_grad, bias_grad, None
build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
 
-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
+                              weight, eps)
 
         return input_grad, weight_grad, None
build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:fd38039c3401b0f6a136f1761c7f396f5954f05e16d78ed1600d8325c1221781
-size 4059256
build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d95e4491d35cb022a6eaa2febbc555f203893f989a4fb1cc483b2632f141869
+size 6687456
build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_cf68df1_dirty
-ops = torch.ops._activation_cf68df1_dirty
+from . import _activation_f517c97_dirty
+ops = torch.ops._activation_f517c97_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_cf68df1_dirty::{op_name}"
+    return f"_activation_f517c97_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py CHANGED
@@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction
 
 
 class PolyNorm(nn.Module):
+
     def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
@@ -28,6 +29,7 @@ class PolyNorm(nn.Module):
 
 
 class RMSNorm(nn.Module):
+
     def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32):
         super().__init__()
         self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype))
build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py CHANGED
@@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
-        bias_grad = (
-            torch.empty(1, dtype=weight.dtype, device=weight.device)
-            if ctx.needs_input_grad[2]
-            else None
-        )
-
-        ops.poly_norm_backward(
-            input_grad, weight_grad, bias_grad, output_grad, input, weight, eps
-        )
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
+        bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device)
+                     if ctx.needs_input_grad[2] else None)
+
+        ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad,
+                               input, weight, eps)
 
         return input_grad, weight_grad, bias_grad, None
build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py CHANGED
@@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function):
         input, weight = ctx.saved_tensors
         eps = ctx.eps
 
-        input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None
-        weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None
+        input_grad = torch.empty_like(
+            input) if ctx.needs_input_grad[0] else None
+        weight_grad = torch.empty_like(
+            weight) if ctx.needs_input_grad[1] else None
 
-        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps)
+        ops.rms_norm_backward(input_grad, weight_grad, output_grad, input,
+                              weight, eps)
 
         return input_grad, weight_grad, None
build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d8a75fc3e8648bbab973e3021720ed372ec8468f7a28b5b047640fd7198ab369
-size 2647872
build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:58116124bb2b5d11de2753dd0c30a1e4c84759f18599da7016c791bad37528e9
+size 2899984
build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_cf68df1_dirty
-ops = torch.ops._activation_cf68df1_dirty
+from . import _activation_f517c97_dirty
+ops = torch.ops._activation_f517c97_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_cf68df1_dirty::{op_name}"
+    return f"_activation_f517c97_dirty::{op_name}"