diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..51ab2aa04188ccf992534dd2a90ba9d0abf2b2cf
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,36 @@
+default_install_hook_types:
+  - pre-commit
+  - commit-msg
+default_stages:
+  - pre-commit  # Run locally
+  - manual  # Run in CI
+exclude: '(build|result)/.*'
+repos:
+- repo: https://github.com/google/yapf
+  rev: v0.43.0
+  hooks:
+  - id: yapf
+    args: [--in-place, --verbose]
+- repo: https://github.com/crate-ci/typos
+  rev: v1.34.0
+  hooks:
+  - id: typos
+- repo: https://github.com/PyCQA/isort
+  rev: 6.0.1
+  hooks:
+  - id: isort
+- repo: https://github.com/pre-commit/mirrors-clang-format
+  rev: v20.1.3
+  hooks:
+  - id: clang-format
+    types_or: [c++, cuda]
+    args: [--style=file, --verbose]
+- repo: https://github.com/jackdewinter/pymarkdown
+  rev: v0.9.29
+  hooks:
+  - id: pymarkdown
+    args: [fix]
+- repo: https://github.com/rhysd/actionlint
+  rev: v1.7.7
+  hooks:
+  - id: actionlint
diff --git a/README.md b/README.md
index 6e3c8df80f785aa7e376414373bdcffd076bca3a..7996c6276b50e7aa8450eb991839c197f6df848a 100644
--- a/README.md
+++ b/README.md
@@ -32,6 +32,7 @@ print(poly_norm(x))
 
 - Test cases are from the Motif LLM
 - You can reproduce the results with:
+
 ```bash
 cd tests
 pytest --run-perf --do-plot
@@ -39,3 +40,47 @@ pytest --run-perf --do-plot
 
 ![PolyNorm Performance](./tests/perf.png)
 
+## Pre-commit Hooks
+
+This project uses [pre-commit](https://pre-commit.com/) to automatically check and format code before commits.
+
+### Setup
+
+1. Install pre-commit:
+
+   ```bash
+   pip install pre-commit
+   ```
+
+2. Install the git hooks:
+
+```bash
+   pre-commit install
+   ```
+
+Once installed, the configured hooks will run automatically on each commit.
+
+### Included Hooks
+
+The following tools are run via pre-commit:
+
+- **[yapf](https://github.com/google/yapf)** – Python code formatter
+- **[typos](https://github.com/crate-ci/typos)** – Spell checker for common typos
+- **[isort](https://github.com/PyCQA/isort)** – Organizes and sorts Python imports
+- **[clang-format](https://clang.llvm.org/docs/ClangFormat.html)** – Formats C++/CUDA code (`--style=file`)
+- **[pymarkdown](https://github.com/jackdewinter/pymarkdown)** – Lints and auto-fixes Markdown files
+- **[actionlint](https://github.com/rhysd/actionlint)** – Validates GitHub Actions workflows
+
+### Usage
+
+- Run all checks on the entire codebase:
+
+  ```bash
+  pre-commit run --all-files
+  ```
+
+- Run a specific hook (example: isort):
+
+  ```bash
+  pre-commit run isort --all-files
+  ```
diff --git a/activation/assert_utils.h b/activation/assert_utils.h
index aeaf14c5040a87ddc1855981c09f33989e96e480..410e5544910358a654c1ab0f65c2aa6a2a773030 100644
--- a/activation/assert_utils.h
+++ b/activation/assert_utils.h
@@ -3,12 +3,15 @@
 #include
 #include
 
-inline void AssertTensorNotNull(const torch::Tensor &tensor, const std::string &name) {
+inline void AssertTensorNotNull(const torch::Tensor &tensor,
+                                const std::string &name) {
   TORCH_INTERNAL_ASSERT(tensor.defined(), name + " tensor should not be null.");
 }
 
-inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a, const torch::Tensor &tensor_b,
-                                   const std::string &name_a, const std::string &name_b) {
+inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a,
+                                   const torch::Tensor &tensor_b,
+                                   const std::string &name_a,
+                                   const std::string &name_b) {
   AssertTensorNotNull(tensor_a, name_a);
   AssertTensorNotNull(tensor_b, name_b);
 
@@ -17,6 +20,7 @@ inline void AssertTensorShapeEqual(const torch::Tensor &tensor_a, const torch::T
   auto tensor_shape_b = tensor_b.sizes();
 
   TORCH_INTERNAL_ASSERT(tensor_shape_a.equals(tensor_shape_b),
-                        "{} tensor shape should be equal to {} tensor shape. (actual: {}, expected: {})",
-                        name_a, name_b, tensor_shape_a, tensor_shape_b);
+                        "{} tensor shape should be equal to {} tensor shape. 
" + "(actual: {}, expected: {})", + name_a, name_b, tensor_shape_a, tensor_shape_b); } diff --git a/activation/atomic_utils.h b/activation/atomic_utils.h index 130d0fd0b3adb63bde5e07f8ce56061c358db9bc..e516e5c37528b03e821e4aac515450f72e05e668 100644 --- a/activation/atomic_utils.h +++ b/activation/atomic_utils.h @@ -1,35 +1,38 @@ #pragma once -#include #include #include +#include namespace motif { -template -__device__ inline void atomic_add(scalar_t* address, acc_t value) { +template +__device__ inline void atomic_add(scalar_t *address, acc_t value) { // TODO: change assert to a static_assert if possible - assert(false && "Unsupported type for atomic_add"); + assert(false && "Unsupported type for atomic_add"); } -template<> -__device__ inline void atomic_add(float* address, float value) { - atomicAdd(address, value); +template <> +__device__ inline void atomic_add(float *address, float value) { + atomicAdd(address, value); } -template<> -__device__ inline void atomic_add(double* address, double value) { - atomicAdd(address, value); +template <> +__device__ inline void atomic_add(double *address, + double value) { + atomicAdd(address, value); } -template<> -__device__ inline void atomic_add(c10::BFloat16* _address, float value) { - volatile c10::BFloat16* address = const_cast(_address); +template <> +__device__ inline void atomic_add(c10::BFloat16 *_address, + float value) { + volatile c10::BFloat16 *address = + const_cast(_address); size_t offset = (size_t)address & 0x2; - volatile uint16_t* address_as_short = - reinterpret_cast(reinterpret_cast(address)); - volatile uint32_t* address_as_uint = - reinterpret_cast(reinterpret_cast(address) - offset); + volatile uint16_t *address_as_short = reinterpret_cast( + reinterpret_cast(address)); + volatile uint32_t *address_as_uint = reinterpret_cast( + reinterpret_cast(address) - offset); bool is_32bit_aligned = offset == 0; uint32_t current = address_as_uint[0]; @@ -39,21 +42,24 @@ __device__ inline void atomic_add(c10::BFloat16* _address, expected = current; c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits()); c10::BFloat16 next_bf16 = current_bf16 + value; - uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_bf16.x - : (current & 0x0000ffff) | (next_bf16.x << 16); - current = atomicCAS(const_cast(address_as_uint), expected, next); + uint32_t next = is_32bit_aligned + ? 
(current & 0xffff0000) | next_bf16.x + : (current & 0x0000ffff) | (next_bf16.x << 16); + current = + atomicCAS(const_cast(address_as_uint), expected, next); } while (current != expected); } -template<> -__device__ inline void atomic_add(c10::Half* _address, float value) { - volatile c10::Half* address = const_cast(_address); +template <> +__device__ inline void atomic_add(c10::Half *_address, + float value) { + volatile c10::Half *address = const_cast(_address); size_t offset = (size_t)address & 0x2; - volatile uint16_t* address_as_short = - reinterpret_cast(reinterpret_cast(address)); - volatile uint32_t* address_as_uint = - reinterpret_cast(reinterpret_cast(address) - offset); + volatile uint16_t *address_as_short = reinterpret_cast( + reinterpret_cast(address)); + volatile uint32_t *address_as_uint = reinterpret_cast( + reinterpret_cast(address) - offset); bool is_32bit_aligned = offset == 0; uint32_t current = address_as_uint[0]; @@ -63,11 +69,12 @@ __device__ inline void atomic_add(c10::Half* _address, float v expected = current; c10::Half current_half(address_as_short[0], c10::Half::from_bits()); c10::Half next_half = current_half + value; - uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_half.x - : (current & 0x0000ffff) | (next_half.x << 16); - current = atomicCAS(const_cast(address_as_uint), expected, next); + uint32_t next = is_32bit_aligned + ? (current & 0xffff0000) | next_half.x + : (current & 0x0000ffff) | (next_half.x << 16); + current = + atomicCAS(const_cast(address_as_uint), expected, next); } while (current != expected); - } } // namespace motif diff --git a/activation/block_reduce.h b/activation/block_reduce.h index a1f3b9b462c8a7574a5b8a33e6e9650824f54d00..61c56e3b4f71646ddfeb5f879bc8169273c285f8 100644 --- a/activation/block_reduce.h +++ b/activation/block_reduce.h @@ -1,7 +1,8 @@ namespace motif { template -__device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d) { +__device__ acc_t _block_reduce_sum(acc_t *shared, const float val, + const int d) { // TODO: Optimize with warp-level primitives __syncthreads(); @@ -17,4 +18,4 @@ __device__ acc_t _block_reduce_sum(acc_t* shared, const float val, const int d) return shared[0]; } -} // motif +} // namespace motif diff --git a/activation/cuda_compat.h b/activation/cuda_compat.h index 1235f55df5e252c0d0216c73c7986f485200723c..389840e90e61d5c84563be7781cf7251b35f5248 100644 --- a/activation/cuda_compat.h +++ b/activation/cuda_compat.h @@ -1,18 +1,20 @@ #pragma once -#ifdef USE_ROCM - #include +#ifndef USE_ROCM +#include +#else +#include +#include #endif #ifndef USE_ROCM - #define WARP_SIZE 32 +#define WARP_SIZE 32 #else - #define WARP_SIZE warpSize +#define WARP_SIZE warpSize #endif #ifndef USE_ROCM - #define VLLM_LDG(arg) __ldg(arg) +#define VLLM_LDG(arg) __ldg(arg) #else - #define VLLM_LDG(arg) *(arg) +#define VLLM_LDG(arg) *(arg) #endif - diff --git a/activation/dispatch_utils.h b/activation/dispatch_utils.h index a85d50e71f2b5faaa4d887ff3c7c808785f2bc05..3a48f364091840aedcf7c942c63be424a276b230 100644 --- a/activation/dispatch_utils.h +++ b/activation/dispatch_utils.h @@ -6,10 +6,11 @@ #include -#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ +#define MOTIF_DISPATCH_CASE_FLOATING_TYPES(...) 
\ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) -#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +#define MOTIF_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + MOTIF_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) diff --git a/activation/poly_norm.cu b/activation/poly_norm.cu index 28edbb6a9caf9b97e7bd8e559b33fd4a91a2403a..a1dcdec4accc3d14602a0fe32701135b24190f74 100644 --- a/activation/poly_norm.cu +++ b/activation/poly_norm.cu @@ -1,246 +1,555 @@ -#include #include -#include +#include #include +#include #include -#include "cuda_compat.h" -#include "dispatch_utils.h" #include "assert_utils.h" #include "atomic_utils.h" #include "block_reduce.h" +#include "cuda_compat.h" +#include "dispatch_utils.h" namespace motif { -template -__global__ void poly_norm_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const scalar_t* __restrict__ weight, // [3] - const scalar_t* __restrict__ bias, // [1] - const float eps, - const int d - ) { +template struct alignas(sizeof(type) * N) type_vec_t { + type data[N]; +}; + +template +__global__ std::enable_if_t<(width > 0)> +poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [3] + const scalar_t *__restrict__ bias, // [1] + const float eps, const int d) { + using vec_t = type_vec_t; + + const int vec_d = d / width; + const int64_t vec_offset = blockIdx.x * vec_d; + const vec_t *__restrict__ input_vec = reinterpret_cast(input); + + acc_t sum2 = 0.0f; + acc_t sum4 = 0.0f; + acc_t sum6 = 0.0f; + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x1 = static_cast(x_vec.data[i]); + acc_t x2 = x1 * x1; + acc_t x4 = x2 * x2; + acc_t x6 = x4 * x2; + + sum2 += x2; + sum4 += x4; + sum6 += x6; + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x); + __syncthreads(); + sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x); + __syncthreads(); + sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x); + + __shared__ acc_t s_bias; + + __shared__ acc_t s_w2_inv_std1; + __shared__ acc_t s_w1_inv_std2; + __shared__ acc_t s_w0_inv_std3; + + if (threadIdx.x == 0) { + acc_t w0 = weight[0]; + acc_t w1 = weight[1]; + acc_t w2 = weight[2]; + s_bias = bias[0]; + + s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2; + s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1; + s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0; + } + __syncthreads(); + + acc_t w2_inv_std1 = s_w2_inv_std1; + acc_t w1_inv_std2 = s_w1_inv_std2; + acc_t w0_inv_std3 = s_w0_inv_std3; + acc_t bias_reg = s_bias; + + vec_t *__restrict__ output_vec = reinterpret_cast(out); + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + vec_t y_vec; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x1 = static_cast(x_vec.data[i]); + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + + acc_t y = + x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg; + + y_vec.data[i] = static_cast(y); + } + output_vec[vec_offset + idx] = y_vec; + } +} + +template 
+__global__ std::enable_if_t<(width == 0)> +poly_norm_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [3] + const scalar_t *__restrict__ bias, // [1] + const float eps, const int d) { const int64_t token_idx = blockIdx.x; - acc_t sum = 0.0f; - acc_t sum_square = 0.0f; - acc_t sum_cube = 0.0f; + acc_t sum2 = 0.0f; + acc_t sum4 = 0.0f; + acc_t sum6 = 0.0f; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - acc_t x = input[token_idx * d + idx]; - sum += pow(x, 2.0f); - sum_square += pow(x, 4.0f); - sum_cube += pow(x, 6.0f); + acc_t x1 = input[token_idx * d + idx]; + acc_t x2 = x1 * x1; + acc_t x4 = x2 * x2; + acc_t x6 = x4 * x2; + + sum2 += x2; + sum4 += x4; + sum6 += x6; } - __shared__ acc_t shared[BLOCK_SIZE]; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; - acc_t mean = _block_reduce_sum(shared, sum, d) / d; - acc_t mean_square = _block_reduce_sum(shared, sum_square, d) / d; - acc_t mean_cube = _block_reduce_sum(shared, sum_cube, d) / d; + sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x); + __syncthreads(); + sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x); + __syncthreads(); + sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x); - acc_t w0 = weight[0]; - acc_t w1 = weight[1]; - acc_t w2 = weight[2]; - acc_t b = bias[0]; + __shared__ acc_t s_bias; + + __shared__ acc_t s_w2_inv_std1; + __shared__ acc_t s_w1_inv_std2; + __shared__ acc_t s_w0_inv_std3; + + if (threadIdx.x == 0) { + acc_t w0 = weight[0]; + acc_t w1 = weight[1]; + acc_t w2 = weight[2]; + s_bias = bias[0]; + + s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2; + s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1; + s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0; + } + __syncthreads(); - acc_t divisor = sqrt(mean + eps); - acc_t divisor_square = sqrt(mean_square + eps); - acc_t divisor_cube = sqrt(mean_cube + eps); + acc_t w2_inv_std1 = s_w2_inv_std1; + acc_t w1_inv_std2 = s_w1_inv_std2; + acc_t w0_inv_std3 = s_w0_inv_std3; + acc_t bias_reg = s_bias; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - acc_t x = input[token_idx * d + idx]; - acc_t x_square = pow(x, 2.0f); - acc_t x_cube = pow(x, 3.0f); - out[token_idx * d + idx] = w2 * x / divisor + - w1 * x_square / divisor_square + - w0 * x_cube / divisor_cube + b; + acc_t x1 = input[token_idx * d + idx]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + out[token_idx * d + idx] = + x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg; } } -template -__global__ void poly_norm_backward_kernel( - scalar_t* __restrict__ input_grad, // [..., d] - acc_t* __restrict__ temp_weight_grad, // [..., 3] - const scalar_t* __restrict__ output_grad, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const scalar_t* __restrict__ weight, // [3] - const float eps, - const int d - ) { - const int64_t token_idx = blockIdx.x; +template +__global__ std::enable_if_t<(width > 0)> +poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] + acc_t *__restrict__ temp_weight_grad, // [..., 3] + acc_t *__restrict__ temp_bias_grad, // [..., 1] + const scalar_t *__restrict__ output_grad, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [3] + const float eps, const int d) { + using vec_t = type_vec_t; + + const int vec_d = d / width; + const int64_t vec_offset = blockIdx.x * vec_d; + const vec_t *__restrict__ input_vec = reinterpret_cast(input); + const 
vec_t *__restrict__ output_grad_vec = + reinterpret_cast(output_grad); + + acc_t sum2 = 0.0f; + acc_t sum4 = 0.0f; + acc_t sum6 = 0.0f; + + acc_t sum_dx1 = 0.0f; + acc_t sum_dx2 = 0.0f; + acc_t sum_dx3 = 0.0f; + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + vec_t dy_vec = output_grad_vec[vec_offset + idx]; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x1 = static_cast(x_vec.data[i]); + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + acc_t x4 = x2 * x2; + acc_t x6 = x3 * x3; + + sum2 += x2; + sum4 += x4; + sum6 += x6; + + acc_t dy = static_cast(dy_vec.data[i]); + + sum_dx1 += dy * x1; + sum_dx2 += dy * x2; + sum_dx3 += dy * x3; + } + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + __syncthreads(); + sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x); + __syncthreads(); + sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x); + __syncthreads(); + sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x); + + __syncthreads(); + sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x); + __syncthreads(); + sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x); + __syncthreads(); + sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x); + + __shared__ acc_t s_mean2; + __shared__ acc_t s_mean4; + __shared__ acc_t s_mean6; + __shared__ acc_t s_sdx1; + __shared__ acc_t s_sdx2; + __shared__ acc_t s_sdx3; + + const acc_t inv_d = acc_t(1) / d; + + if (threadIdx.x == 0) { + s_mean2 = sum2 * inv_d + eps; + s_mean4 = sum4 * inv_d + eps; + s_mean6 = sum6 * inv_d + eps; + + s_sdx1 = sum_dx1 * inv_d; + s_sdx2 = sum_dx2 * inv_d; + s_sdx3 = sum_dx3 * inv_d; + } + __syncthreads(); acc_t w0 = weight[0]; acc_t w1 = weight[1]; acc_t w2 = weight[2]; - acc_t sum_2 = 0.0f; - acc_t sum_4 = 0.0f; - acc_t sum_6 = 0.0f; + acc_t mean2 = s_mean2; + acc_t mean4 = s_mean4; + acc_t mean6 = s_mean6; + acc_t sdx1 = s_sdx1; + acc_t sdx2 = s_sdx2; + acc_t sdx3 = s_sdx3; + + acc_t inv_std1 = rsqrtf(mean2); + acc_t inv_std2 = rsqrtf(mean4); + acc_t inv_std3 = rsqrtf(mean6); + + // inv_std / mean == powf(mean, -1.5) + acc_t c1 = w2 * inv_std1 / mean2; + acc_t c2 = acc_t(2) * w1 * inv_std2 / mean4; + acc_t c3 = acc_t(3) * w0 * inv_std3 / mean6; + + acc_t sum_dy = 0; + acc_t sum_dw0 = 0; + acc_t sum_dw1 = 0; + acc_t sum_dw2 = 0; + + vec_t *__restrict__ input_grad_vec = reinterpret_cast(input_grad); + + for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) { + vec_t x_vec = input_vec[vec_offset + idx]; + vec_t dy_vec = output_grad_vec[vec_offset + idx]; + vec_t dx_vec; + +#pragma unroll + for (int i = 0; i < width; ++i) { + acc_t x1 = static_cast(x_vec.data[i]); + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + acc_t dy = static_cast(dy_vec.data[i]); + + if (input_grad) { + acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3); + acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2); + acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1); + dx_vec.data[i] = static_cast(dx1 + dx2 + dx3); + } + + sum_dy += dy; + sum_dw0 += dy * (x3 * inv_std3); + sum_dw1 += dy * (x2 * inv_std2); + sum_dw2 += dy * (x1 * inv_std1); + } + + if (input_grad) { + input_grad_vec[vec_offset + idx] = dx_vec; + } + } + + sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x); + __syncthreads(); + sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x); + __syncthreads(); + sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x); + __syncthreads(); + sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x); + + if 
(threadIdx.x == 0) { + temp_bias_grad[blockIdx.x] = sum_dy; + temp_weight_grad[blockIdx.x * 3 + 0] = sum_dw0; + temp_weight_grad[blockIdx.x * 3 + 1] = sum_dw1; + temp_weight_grad[blockIdx.x * 3 + 2] = sum_dw2; + } +} + +template +__global__ std::enable_if_t<(width == 0)> +poly_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] + acc_t *__restrict__ temp_weight_grad, // [..., 3] + acc_t *__restrict__ temp_bias_grad, // [..., 1] + const scalar_t *__restrict__ output_grad, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [3] + const float eps, const int d) { + const int64_t token_idx = blockIdx.x; + + acc_t sum2 = 0.0f; + acc_t sum4 = 0.0f; + acc_t sum6 = 0.0f; - acc_t sum_dx_1 = 0.0f; - acc_t sum_dx_2 = 0.0f; - acc_t sum_dx_3 = 0.0f; + acc_t sum_dx1 = 0.0f; + acc_t sum_dx2 = 0.0f; + acc_t sum_dx3 = 0.0f; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { acc_t dy = output_grad[token_idx * d + idx]; - acc_t x_1 = input[token_idx * d + idx]; - acc_t x_2 = x_1 * x_1; - acc_t x_3 = x_2 * x_1; - acc_t x_4 = x_2 * x_2; - acc_t x_6 = x_3 * x_3; + acc_t x1 = input[token_idx * d + idx]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; + acc_t x4 = x2 * x2; + acc_t x6 = x3 * x3; - sum_2 += x_2; - sum_4 += x_4; - sum_6 += x_6; + sum2 += x2; + sum4 += x4; + sum6 += x6; - sum_dx_1 += dy * x_1; - sum_dx_2 += dy * x_2; - sum_dx_3 += dy * x_3; + sum_dx1 += dy * x1; + sum_dx2 += dy * x2; + sum_dx3 += dy * x3; } - __shared__ acc_t shared[BLOCK_SIZE]; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + __syncthreads(); + sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x); + __syncthreads(); + sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x); + __syncthreads(); + sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x); + + __syncthreads(); + sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x); + __syncthreads(); + sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x); + __syncthreads(); + sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x); + + __shared__ acc_t s_mean2; + __shared__ acc_t s_mean4; + __shared__ acc_t s_mean6; + __shared__ acc_t s_sdx1; + __shared__ acc_t s_sdx2; + __shared__ acc_t s_sdx3; + + const acc_t inv_d = acc_t(1) / d; + + if (threadIdx.x == 0) { + s_mean2 = sum2 * inv_d + eps; + s_mean4 = sum4 * inv_d + eps; + s_mean6 = sum6 * inv_d + eps; + + s_sdx1 = sum_dx1 * inv_d; + s_sdx2 = sum_dx2 * inv_d; + s_sdx3 = sum_dx3 * inv_d; + } + __syncthreads(); - acc_t mean_2 = _block_reduce_sum(shared, sum_2, d) / d + eps; - acc_t mean_4 = _block_reduce_sum(shared, sum_4, d) / d + eps; - acc_t mean_6 = _block_reduce_sum(shared, sum_6, d) / d + eps; + acc_t w0 = weight[0]; + acc_t w1 = weight[1]; + acc_t w2 = weight[2]; - sum_dx_1 = _block_reduce_sum(shared, sum_dx_1, d); - sum_dx_2 = _block_reduce_sum(shared, sum_dx_2, d); - sum_dx_3 = _block_reduce_sum(shared, sum_dx_3, d); + acc_t mean2 = s_mean2; + acc_t mean4 = s_mean4; + acc_t mean6 = s_mean6; + acc_t sdx1 = s_sdx1; + acc_t sdx2 = s_sdx2; + acc_t sdx3 = s_sdx3; - acc_t _mean_2 = powf(mean_2, -1.5); - acc_t _mean_4 = powf(mean_4, -1.5); - acc_t _mean_6 = powf(mean_6, -1.5); + acc_t inv_std1 = rsqrtf(mean2); + acc_t inv_std2 = rsqrtf(mean4); + acc_t inv_std3 = rsqrtf(mean6); - acc_t sq_mean_2 = sqrtf(mean_2); - acc_t sq_mean_4 = sqrtf(mean_4); - acc_t sq_mean_6 = sqrtf(mean_6); + // inv_std / mean == powf(mean, -1.5) + acc_t c1 = w2 * inv_std1 / mean2; + acc_t c2 = acc_t(2) * w1 * inv_std2 / 
mean4; + acc_t c3 = acc_t(3) * w0 * inv_std3 / mean6; + acc_t sum_dy = 0; acc_t sum_dw0 = 0; acc_t sum_dw1 = 0; acc_t sum_dw2 = 0; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { acc_t dy = output_grad[token_idx * d + idx]; - acc_t x_1 = input[token_idx * d + idx]; - acc_t x_2 = x_1 * x_1; - acc_t x_3 = x_2 * x_1; - - acc_t dx_3 = - _mean_6 * 3 * x_2 * (dy * mean_6 - x_3 * sum_dx_3 / d) * w0; - acc_t dx_2 = - _mean_4 * 2 * x_1 * (dy * mean_4 - x_2 * sum_dx_2 / d) * w1; - acc_t dx_1 = - _mean_2 * (dy * mean_2 - x_1 * sum_dx_1 / d) * w2; + acc_t x1 = input[token_idx * d + idx]; + acc_t x2 = x1 * x1; + acc_t x3 = x2 * x1; if (input_grad) { - input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3; + acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3); + acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2); + acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1); + input_grad[token_idx * d + idx] = dx1 + dx2 + dx3; } - sum_dw0 += dy * (x_3 / sq_mean_6); - sum_dw1 += dy * (x_2 / sq_mean_4); - sum_dw2 += dy * (x_1 / sq_mean_2); + sum_dy += dy; + sum_dw0 += dy * (x3 * inv_std3); + sum_dw1 += dy * (x2 * inv_std2); + sum_dw2 += dy * (x1 * inv_std1); } - if (temp_weight_grad) { - sum_dw0 = _block_reduce_sum(shared, sum_dw0, d); - sum_dw1 = _block_reduce_sum(shared, sum_dw1, d); - sum_dw2 = _block_reduce_sum(shared, sum_dw2, d); - - if (threadIdx.x == 0) { - temp_weight_grad[token_idx * 3 + 0] = sum_dw0; - temp_weight_grad[token_idx * 3 + 1] = sum_dw1; - temp_weight_grad[token_idx * 3 + 2] = sum_dw2; - } + sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x); + __syncthreads(); + sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x); + __syncthreads(); + sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x); + __syncthreads(); + sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x); + + if (threadIdx.x == 0) { + temp_bias_grad[token_idx] = sum_dy; + temp_weight_grad[token_idx * 3 + 0] = sum_dw0; + temp_weight_grad[token_idx * 3 + 1] = sum_dw1; + temp_weight_grad[token_idx * 3 + 2] = sum_dw2; } } -} // namespace motif - - -void poly_norm(torch::Tensor& out, // [..., d] - const torch::Tensor& input, // [..., d] - const torch::Tensor& weight, // [3] - const torch::Tensor& bias, // [1] - double eps) -{ +} // namespace motif + +#define LAUNCH_POLY_NORM(width) \ + MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \ + motif::poly_norm_kernel \ + <<>>( \ + out.data_ptr(), input.data_ptr(), \ + weight.data_ptr(), bias.data_ptr(), eps, d); \ + }); + +void poly_norm(torch::Tensor &out, // [..., d] + const torch::Tensor &input, // [..., d] + const torch::Tensor &weight, // [3] + const torch::Tensor &bias, // [1] + double eps) { AssertTensorShapeEqual(input, out, "input", "out"); AssertTensorNotNull(weight, "weight"); AssertTensorNotNull(bias, "bias"); // TODO shape check - constexpr int BLOCK_SIZE = 256; - int d = input.size(-1); - int64_t num_tokens = input.numel() / input.size(-1); + int64_t num_tokens = input.numel() / d; dim3 grid(num_tokens); - dim3 block(BLOCK_SIZE); + const int max_block_size = (num_tokens < 256) ? 
1024 : 256; + dim3 block(std::min(d, max_block_size)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - MOTIF_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "poly_norm_kernel", [&] { - motif::poly_norm_kernel - <<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - bias.data_ptr(), eps, d); - } - ); + if (d % 8 == 0) { + LAUNCH_POLY_NORM(8); + } else { + LAUNCH_POLY_NORM(0); + } } -void poly_norm_backward( - torch::Tensor& input_grad, // [..., d] - torch::Tensor& weight_grad, // [..., d] - torch::Tensor& bias_grad, // [..., d] - const torch::Tensor& output_grad, // [3] - const torch::Tensor& input, // [3] - const torch::Tensor& weight, // [3] - double eps) { +#define LAUNCH_POLY_NORM_BACKWARD(width) \ + MOTIF_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "poly_norm_backward_kernel", [&] { \ + motif::poly_norm_backward_kernel \ + <<>>(input_grad.data_ptr(), \ + temp_weight_grad.data_ptr(), \ + temp_bias_grad.data_ptr(), \ + output_grad.data_ptr(), \ + input.data_ptr(), \ + weight.data_ptr(), eps, d); \ + }); + +void poly_norm_backward(torch::Tensor &input_grad, // [..., d] + torch::Tensor &weight_grad, // [3] + torch::Tensor &bias_grad, // [1] + const torch::Tensor &output_grad, // [..., d] + const torch::Tensor &input, // [..., d] + const torch::Tensor &weight, // [3] + double eps) { AssertTensorShapeEqual(input, input_grad, "input", "input_grad"); AssertTensorShapeEqual(input, output_grad, "input", "output_grad"); AssertTensorNotNull(weight, "weight"); // TODO shape check // weight_grad, bias_grad and input_grad can be nullable - constexpr int BLOCK_SIZE = 256; - int d = input.size(-1); - int64_t num_tokens = input.numel() / input.size(-1); + int64_t num_tokens = input.numel() / d; dim3 grid(num_tokens); - dim3 block(BLOCK_SIZE); + const int max_block_size = (num_tokens < 256) ? 
1024 : 256; + dim3 block(std::min(d, max_block_size)); torch::Tensor temp_weight_grad = - torch::empty({num_tokens, 3}, - input.options().dtype(torch::kFloat)); - - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::empty({num_tokens, 3}, input.options().dtype(torch::kFloat)); + torch::Tensor temp_bias_grad = + torch::empty({num_tokens, 1}, output_grad.options().dtype(torch::kFloat)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); - MOTIF_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "poly_norm_backward_kernel", [&] { - motif::poly_norm_backward_kernel - <<>>( - input_grad.data_ptr(), - temp_weight_grad.data_ptr(), - output_grad.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - eps, d); - } - ); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if (d % 8 == 0) { + LAUNCH_POLY_NORM_BACKWARD(8); + } else { + LAUNCH_POLY_NORM_BACKWARD(0); + } if (bias_grad.defined()) { - at::sum_out(bias_grad, output_grad); - bias_grad.resize_({1}); + torch::Tensor acc = torch::empty_like(bias_grad, temp_bias_grad.options()); + at::sum_out(acc, temp_bias_grad, {0}); + bias_grad.copy_(acc); } if (weight_grad.defined()) { - at::sum_out(weight_grad, temp_weight_grad, {0}); + torch::Tensor acc = + torch::empty_like(weight_grad, temp_weight_grad.options()); + at::sum_out(acc, temp_weight_grad, {0}); + weight_grad.copy_(acc); } } diff --git a/activation/rms_norm.cu b/activation/rms_norm.cu index fa5f6cfbdbd24d45d296b43cb2a56941e26eac76..9a2ffd0ac322acad34e56a1b2d53fde6178b714a 100644 --- a/activation/rms_norm.cu +++ b/activation/rms_norm.cu @@ -1,26 +1,23 @@ -#include #include -#include +#include #include +#include #include -#include "cuda_compat.h" -#include "dispatch_utils.h" #include "assert_utils.h" #include "atomic_utils.h" #include "block_reduce.h" +#include "cuda_compat.h" +#include "dispatch_utils.h" namespace motif { template -__global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const scalar_t* __restrict__ weight, // [d] - const float eps, - const int d - ) { +__global__ void rms_norm_kernel(scalar_t *__restrict__ out, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { const int64_t token_idx = blockIdx.x; const int64_t vec_idx = threadIdx.x; @@ -44,15 +41,13 @@ __global__ void rms_norm_kernel( } template -__global__ void rms_norm_backward_kernel( - scalar_t* __restrict__ input_grad, // [..., d] - acc_t* __restrict__ temp_weight_grad, // [..., d] - const scalar_t* __restrict__ output_grad, // [..., d] - const scalar_t* __restrict__ input, // [..., d] - const scalar_t* __restrict__ weight, // [d] - const float eps, - const int d - ) { +__global__ void +rms_norm_backward_kernel(scalar_t *__restrict__ input_grad, // [..., d] + acc_t *__restrict__ temp_weight_grad, // [..., d] + const scalar_t *__restrict__ output_grad, // [..., d] + const scalar_t *__restrict__ input, // [..., d] + const scalar_t *__restrict__ weight, // [d] + const float eps, const int d) { const int64_t token_idx = blockIdx.x; const int64_t vec_idx = threadIdx.x; acc_t d_sum = 0.0f; @@ -80,8 +75,7 @@ __global__ void rms_norm_backward_kernel( acc_t dy = output_grad[token_idx * d + idx]; acc_t w = weight[idx]; - input_grad[token_idx * d + idx] = - scale * dy * w - dxx * x; + input_grad[token_idx * d + idx] = scale * dy * w - dxx * x; if (temp_weight_grad) { temp_weight_grad[token_idx * d + idx] = dy * x * scale; @@ 
-89,14 +83,12 @@ __global__ void rms_norm_backward_kernel( } } -} // namespace motif - +} // namespace motif -void rms_norm(torch::Tensor& out, // [..., d] - const torch::Tensor& input, // [..., d] - const torch::Tensor& weight, // [d] - double eps) -{ +void rms_norm(torch::Tensor &out, // [..., d] + const torch::Tensor &input, // [..., d] + const torch::Tensor &weight, // [d] + double eps) { AssertTensorShapeEqual(input, out, "input", "out"); AssertTensorNotNull(weight, "weight"); // TODO shape check @@ -110,25 +102,20 @@ void rms_norm(torch::Tensor& out, // [..., d] const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - MOTIF_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "rms_norm_kernel", [&] { - motif::rms_norm_kernel - <<>>( - out.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - eps, d); - } - ); + MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + motif::rms_norm_kernel + <<>>(out.data_ptr(), + input.data_ptr(), + weight.data_ptr(), eps, d); + }); } -void rms_norm_backward( - torch::Tensor& input_grad, // [..., d] - torch::Tensor& weight_grad, // [..., d] - const torch::Tensor& output_grad, // [d] - const torch::Tensor& input, // [d] - const torch::Tensor& weight, // [d] - double eps) { +void rms_norm_backward(torch::Tensor &input_grad, // [..., d] + torch::Tensor &weight_grad, // [..., d] + const torch::Tensor &output_grad, // [d] + const torch::Tensor &input, // [d] + const torch::Tensor &weight, // [d] + double eps) { AssertTensorShapeEqual(input, input_grad, "input", "input_grad"); AssertTensorShapeEqual(input, output_grad, "input", "output_grad"); AssertTensorNotNull(weight, "weight"); @@ -143,24 +130,20 @@ void rms_norm_backward( dim3 block(BLOCK_SIZE); torch::Tensor temp_weight_grad = - torch::empty({num_tokens, d}, - input.options().dtype(torch::kFloat)); + torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); MOTIF_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "rms_norm_backward_kernel", [&] { - motif::rms_norm_backward_kernel - <<>>( - input_grad.data_ptr(), - temp_weight_grad.data_ptr(), - output_grad.data_ptr(), - input.data_ptr(), - weight.data_ptr(), - eps, d); - } - ); + input.scalar_type(), "rms_norm_backward_kernel", [&] { + motif::rms_norm_backward_kernel + <<>>(input_grad.data_ptr(), + temp_weight_grad.data_ptr(), + output_grad.data_ptr(), + input.data_ptr(), + weight.data_ptr(), eps, d); + }); if (weight_grad.defined()) { at::sum_out(weight_grad, temp_weight_grad, {0}); diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index fee453a5f3d827219ce25b2491d12dd3883ef044..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:1655e52503ce7d0b7dabd55b97c1bd7d11071cbe0f80b9e810c443523638fd9b -size 2994312 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f3a874e78aac8a38f35e3d3aa4d26c892c9a0d66 --- /dev/null 
+++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd84c828d4c15e96d65d6c8f0eb7a945ee8167d92e978b2ebce03eeaf41e7fce +size 4405112 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class 
RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index 1a2c10a5ac919201397afa3d9ffa2f8c49f434b9..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:174dbe4375aa22fb34d9d23630b3bec4eeb95635ef681b665db0985e78cf5af3 -size 3027504 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..689760116de97c954865cd824732f04d2f746728 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:caffcadbb99fbaa27e8a81d5ef508f2e1a798e7626d618c3cf5b0d387d2c8686 +size 4618624 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index 
af7500bce257ce738f3a19d5b1330bc6dae90856..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:91d71ca84a19b393c22b269226a7b4ddadbf1feec73a80bd45f655179c7a53f5 -size 3987512 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..45881f2bf18843120634173e5a0974ebdcbe07c6 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b7c6ece8e8d316c4cc5fe46b1cec4422b2f61e9bb7240af71a2b4a35975d8e6 +size 6676528 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, 
dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index d44ecea9b84cd8373497342429108adfcb3021cb..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:ab1037bf6b41bf2be1d00a6a0ed01a97a5e4d64dd0abaf509492ad31eea0a576 -size 2642976 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..6e05f5b3045576c970e67481e0182f9aaf5a88d2 --- /dev/null +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4be173820e2a4bf4b6b8de6b63faf6544b599d9b0583f650a940adaef4a048b3 +size 2899184 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +++ b/build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index 
8cb08e411d00e0c33bd32a334b13f032afabf622..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:012788f2064588edf60df24778dff33f8ca95e3b1aaf5243554735cd783dd7ed -size 3032488 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..6c12e8b587a01fe10f4e73cca22a5a27fd2e794a --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cb222449350310f90f7271f34fcf9052c9eec28021fee0348130a8f239a97bf4 +size 4571976 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, 
dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index fe1ba22a5e5b4dd967450df3fb72dcf989ff3c49..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b1a65b79b750f550a09e6a1142b5151b03b2a60ec6115a264e6d8de3cac7ee5d -size 4000920 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ff5ceef3b840a9957dab36434074fa21417f6711 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79be6527f579de1133e50a66310d7d0690649dcac63009a54b5e68809408f12a +size 6634208 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index 
eb2ce397183a075e5f4923b2d2f551bf761331ef..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:fd38039c3401b0f6a136f1761c7f396f5954f05e16d78ed1600d8325c1221781 -size 4059256 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f7ab393218a3d825e10b9e1e838440d8a543ce19 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8d95e4491d35cb022a6eaa2febbc555f203893f989a4fb1cc483b2632f141869 +size 6687456 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, 
dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index d7e1620ead5192b60900f00994daaca04dd7b6c0..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d8a75fc3e8648bbab973e3021720ed372ec8468f7a28b5b047640fd7198ab369 -size 2647872 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..1843d54d5917206c0947de8effc1cf347ea9e853 --- /dev/null +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:58116124bb2b5d11de2753dd0c30a1e4c84759f18599da7016c791bad37528e9 +size 2899984 diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-rocm63-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so deleted file mode 100644 index 
1e9b12a0c73081a688a6c2da9b54740dd3aeaa61..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_cf68df1_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:cccb0567a8f86f1f9e23a653a2e1f7177f4528cb1ecf8cbec42e40c60392eb39 -size 2633232 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..86ae5f11c05134ad7347aca293b13aeff2caf4c1 --- /dev/null +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_activation_f517c97_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65319d3d93ac3bf0f2939fa4e53ddfc8cd633b9e396cde3a97d63b9041ba03a7 +size 2885344 diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py index ecb351fdb0e6990bab29e42a4aecd1896b77a1b2..11632044e1d56e11f7646a5a027b0aea5439e2af 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _activation_cf68df1_dirty -ops = torch.ops._activation_cf68df1_dirty +from . import _activation_f517c97_dirty +ops = torch.ops._activation_f517c97_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_activation_cf68df1_dirty::{op_name}" \ No newline at end of file + return f"_activation_f517c97_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = 
(torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py +++ b/build/torch28-cxx11-rocm64-x86_64-linux/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/tests/conftest.py b/tests/conftest.py index de56b513008111d0efdb808165c3f15750761cfd..eda6569be97917dc009f15b3ca4f11136708aa0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,8 +33,7 @@ def plot(perf_results: list[PerfResult]): textfont=dict(size=14), textposition="outside", # width=[bar_width] * len(x_labels), - ) - ) + )) fig.add_trace( go.Bar( @@ -46,12 +45,12 @@ def plot(perf_results: list[PerfResult]): textfont=dict(size=14), textposition="outside", # width=[bar_width] * len(x_labels), - ) - ) + )) fig.update_layout( title=dict( - text="Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)", + text= + "Speedup over torch (higher is better) (MI250, torch 2.7, ROCm 6.3)", font=dict(size=24), ), legend=dict( @@ -96,12 +95,14 @@ def plot(perf_results: list[PerfResult]): def pytest_addoption(parser): - parser.addoption( - "--run-perf", action="store_true", default=False, help="Run perf tests" - ) - parser.addoption( - "--do-plot", action="store_true", default=False, help="Plot performance results" - ) + parser.addoption("--run-perf", + action="store_true", + default=False, + help="Run perf tests") + parser.addoption("--do-plot", + action="store_true", + default=False, + help="Plot performance results") @pytest.fixture @@ -117,10 +118,10 @@ def pytest_configure(config): if DO_PLOT and not run_perf: raise ValueError( "Cannot plot performance results without running performance tests. " - "Please use --run-perf option." 
- ) + "Please use --run-perf option.") - config.addinivalue_line("markers", "perf: mark test as performance-related") + config.addinivalue_line("markers", + "perf: mark test as performance-related") def pytest_collection_modifyitems(config, items): @@ -128,8 +129,7 @@ def pytest_collection_modifyitems(config, items): skip_perf = pytest.mark.skip(reason="need --run-perf option to run") skip_normal = pytest.mark.skip( - reason="normal tests skipped when --run-perf is used" - ) + reason="normal tests skipped when --run-perf is used") for item in items: if "perf" in item.keywords and not run_perf: item.add_marker(skip_perf) diff --git a/tests/kernels/allclose_default.py b/tests/kernels/allclose_default.py index 80eb1eeb9fb738d70efe28d64df98b2ff7223463..175cfe82fb74e5b573c27f708a036810f2188c8d 100644 --- a/tests/kernels/allclose_default.py +++ b/tests/kernels/allclose_default.py @@ -3,7 +3,11 @@ import torch # Reference default values of atol and rtol are from # https://github.com/pytorch/pytorch/blob/6d96beb6bec24d73ee3f080bac54d2104068f675/test/test_transformers.py#L67 default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5} -default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6} +default_rtol = { + torch.float16: 1e-3, + torch.bfloat16: 1.6e-2, + torch.float: 1.3e-6 +} def get_default_atol(output) -> float: diff --git a/tests/kernels/test_poly_norm.py b/tests/kernels/test_poly_norm.py index 796a5a136c8fd0b9bf9021c81962fb4636234d9a..dab3d26a722e7a1db6917e3745df165ed2f08ffb 100644 --- a/tests/kernels/test_poly_norm.py +++ b/tests/kernels/test_poly_norm.py @@ -13,23 +13,20 @@ DTYPES = [torch.float, torch.bfloat16, torch.half] NUM_TOKENS = [7, 13] # Arbitrary values for testing D = [513] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] def norm(x, eps: float) -> torch.Tensor: return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps) -def poly_norm( - x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float -) -> torch.Tensor: +def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, + eps: float) -> torch.Tensor: x = x.float() - return ( - weight[0] * norm(x**3, eps) - + weight[1] * norm(x**2, eps) - + weight[2] * norm(x, eps) - + bias - ).to(weight.dtype) + return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) + + weight[2] * norm(x, eps) + bias).to(weight.dtype) @pytest.mark.parametrize("num_tokens", NUM_TOKENS) diff --git a/tests/kernels/test_poly_norm_perf.py b/tests/kernels/test_poly_norm_perf.py index c89f9e7c55481c139df11d1be2622eaa6a3f6564..88c35f34e8acb6cb835eedb7c1d6657e16b28da3 100644 --- a/tests/kernels/test_poly_norm_perf.py +++ b/tests/kernels/test_poly_norm_perf.py @@ -94,7 +94,8 @@ def test_poly_norm( return start.elapsed_time(end) / NUM_REP kernel_time_ms = time_cuda(lambda: layer(x)) - torch_fn_time = time_cuda(lambda: torch_fn(x_ref, weight_ref, bias_ref, eps)) + torch_fn_time = time_cuda( + lambda: torch_fn(x_ref, weight_ref, bias_ref, eps)) PERF_RESULTS.append( PerfResult( @@ -103,11 +104,12 @@ def test_poly_norm( dtype=dtype, kernel_time_ms=kernel_time_ms, torch_time_ms=torch_fn_time, - ) - ) + )) - kernel_time_ms = time_cuda(lambda: mod_out.backward(out_grad, retain_graph=True)) - torch_fn_time = time_cuda(lambda: ref_out.backward(out_grad, retain_graph=True)) + kernel_time_ms = time_cuda( + lambda: 
mod_out.backward(out_grad, retain_graph=True)) + torch_fn_time = time_cuda( + lambda: ref_out.backward(out_grad, retain_graph=True)) PERF_RESULTS.append( PerfResult( @@ -116,5 +118,4 @@ def test_poly_norm( dtype=dtype, kernel_time_ms=kernel_time_ms, torch_time_ms=torch_fn_time, - ) - ) + )) diff --git a/tests/kernels/test_rms_norm.py b/tests/kernels/test_rms_norm.py index 9ea805d4fe9f97715aa7633e37381065a319b130..20fef91a1b9a82c380f04c273e4b169636f98cca 100644 --- a/tests/kernels/test_rms_norm.py +++ b/tests/kernels/test_rms_norm.py @@ -13,7 +13,9 @@ DTYPES = [torch.float, torch.bfloat16, torch.half] NUM_TOKENS = [7, 13] # Arbitrary values for testing D = [513] # Arbitrary values for testing SEEDS = [0] -CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index afcb16fcb0f3ca2bef8fa7b7f0b8b3b2a3acc4af..e5a2807b6ffb0f26ccf1cf61a28a9bb13b294c60 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -46,15 +46,19 @@ def fp8_allclose( """ Reference implementation of torch.allclose """ - torch._refs._check_close_args(name="torch.allclose", a=a, b=b, rtol=rtol, atol=atol) + torch._refs._check_close_args(name="torch.allclose", + a=a, + b=b, + rtol=rtol, + atol=atol) return bool( torch.all( - torch.isclose( - a.double(), b.double(), rtol=rtol, atol=atol, equal_nan=equal_nan - ) - ).item() - ) + torch.isclose(a.double(), + b.double(), + rtol=rtol, + atol=atol, + equal_nan=equal_nan)).item()) # A special version of op check that has a restricted default set of test_utils @@ -73,10 +77,9 @@ def opcheck( cond: bool = True, ) -> Dict[str, str]: with unittest.mock.patch("torch.allclose", new=fp8_allclose): - return ( - torch.library.opcheck( - op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception - ) - if cond - else {} - ) + return (torch.library.opcheck(op, + args, + kwargs, + test_utils=test_utils, + raise_exception=raise_exception) + if cond else {}) diff --git a/torch-ext/activation/layers.py b/torch-ext/activation/layers.py index 3824e5de50583f385215cf90adc03aff91653e2e..8ec01852a54649c04ba10f50aecb3f1a41576d18 100644 --- a/torch-ext/activation/layers.py +++ b/torch-ext/activation/layers.py @@ -7,6 +7,7 @@ from .rms_norm import RMSNormFunction class PolyNorm(nn.Module): + def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3) @@ -28,6 +29,7 @@ class PolyNorm(nn.Module): class RMSNorm(nn.Module): + def __init__(self, dim: int, eps=1e-6, dtype: torch.dtype = torch.float32): super().__init__() self.weight = torch.nn.Parameter(torch.ones(dim, dtype=dtype)) diff --git a/torch-ext/activation/poly_norm.py b/torch-ext/activation/poly_norm.py index ce14e5cd8078de06d964775e34e1e668df32493e..e9f13435bd79b865dca42a4d84a9fd7e9f3ea479 100644 --- a/torch-ext/activation/poly_norm.py +++ b/torch-ext/activation/poly_norm.py @@ -26,16 +26,14 @@ class PolyNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None - bias_grad = ( - torch.empty(1, dtype=weight.dtype, device=weight.device) - if ctx.needs_input_grad[2] - else None - ) - - ops.poly_norm_backward( - input_grad, 
weight_grad, bias_grad, output_grad, input, weight, eps - ) + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None + bias_grad = (torch.empty(1, dtype=weight.dtype, device=weight.device) + if ctx.needs_input_grad[2] else None) + + ops.poly_norm_backward(input_grad, weight_grad, bias_grad, output_grad, + input, weight, eps) return input_grad, weight_grad, bias_grad, None diff --git a/torch-ext/activation/rms_norm.py b/torch-ext/activation/rms_norm.py index 53df35e855d5d1591bc6fceff50ba81afdb2c873..4ce81d593f3edd8f4d14ce41f822e8547188c049 100644 --- a/torch-ext/activation/rms_norm.py +++ b/torch-ext/activation/rms_norm.py @@ -26,9 +26,12 @@ class RMSNormFunction(torch.autograd.Function): input, weight = ctx.saved_tensors eps = ctx.eps - input_grad = torch.empty_like(input) if ctx.needs_input_grad[0] else None - weight_grad = torch.empty_like(weight) if ctx.needs_input_grad[1] else None + input_grad = torch.empty_like( + input) if ctx.needs_input_grad[0] else None + weight_grad = torch.empty_like( + weight) if ctx.needs_input_grad[1] else None - ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, weight, eps) + ops.rms_norm_backward(input_grad, weight_grad, output_grad, input, + weight, eps) return input_grad, weight_grad, None diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp index bf1d4eaa15ee3880bc7c84b229dd5b8ce349b176..74fb10647ac4d9c3704d2cbf41f235d2c6495e56 100644 --- a/torch-ext/torch_binding.cpp +++ b/torch-ext/torch_binding.cpp @@ -5,18 +5,21 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Activation ops - ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float eps) -> ()"); - ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! bias_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()"); + ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, " + "float eps) -> ()"); + ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! " + "bias_grad, Tensor output_grad, Tensor input, Tensor weight, float " + "eps) -> ()"); ops.impl("poly_norm", torch::kCUDA, &poly_norm); ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward); // Activation ops - ops.def("rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()"); - ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor output_grad, Tensor input, Tensor weight, float eps) -> ()"); + ops.def( + "rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()"); + ops.def("rms_norm_backward(Tensor! input_grad, Tensor! 
weight_grad, Tensor " + "output_grad, Tensor input, Tensor weight, float eps) -> ()"); ops.impl("rms_norm", torch::kCUDA, &rms_norm); ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward); - - } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h index 4a55825a567d46197953318daff1180586da8bd0..a64ae48061cf3b34405793564ebd1c9d76ffb006 100644 --- a/torch-ext/torch_binding.h +++ b/torch-ext/torch_binding.h @@ -2,8 +2,18 @@ #include -void poly_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, const torch::Tensor &bias, double eps); -void poly_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, torch::Tensor& bias_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps); +void poly_norm(torch::Tensor &out, const torch::Tensor &input, + const torch::Tensor &weights, const torch::Tensor &bias, + double eps); +void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad, + torch::Tensor &bias_grad, + const torch::Tensor &output_grad, + const torch::Tensor &input, const torch::Tensor &weight, + double eps); -void rms_norm(torch::Tensor &out, const torch::Tensor &input, const torch::Tensor &weights, double eps); -void rms_norm_backward(torch::Tensor& input_grad, torch::Tensor& weight_grad, const torch::Tensor& output_grad, const torch::Tensor& input, const torch::Tensor& weight, double eps); +void rms_norm(torch::Tensor &out, const torch::Tensor &input, + const torch::Tensor &weights, double eps); +void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad, + const torch::Tensor &output_grad, + const torch::Tensor &input, const torch::Tensor &weight, + double eps);
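
A minimal usage sketch for the code paths touched above, tying the `PolyNorm` module in `layers.py` to the `poly_norm` / `poly_norm_backward` ops declared in `torch_binding.cpp`. This is illustrative only and not part of the diff: the import path `activation.layers.PolyNorm` and the presence of a CUDA device are assumptions, and the tensor sizes simply mirror the values used in `tests/kernels/test_poly_norm.py`.

```python
import torch

from activation.layers import PolyNorm  # assumed export path; see layers.py above

device = "cuda"  # the ops are registered for torch::kCUDA only
layer = PolyNorm(eps=1e-6, dtype=torch.float32).to(device)

# Shapes mirror NUM_TOKENS / D from the tests (7 tokens, hidden size 513).
x = torch.randn(7, 513, device=device, requires_grad=True)

# Forward dispatches to the registered op with the
# "poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, float eps)" schema.
out = layer(x)

# Backward allocates gradient buffers only where needs_input_grad is set and
# calls ops.poly_norm_backward(input_grad, weight_grad, bias_grad, ...).
out.sum().backward()

print(out.shape, x.grad.shape, layer.weight.grad.shape)
```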