TaehyunKim (TaehyunKimMotif) committed
Commit a1e5ca8 · unverified · 1 parent: e5e2eeb

Fix fused add rms norm (#4)


* make fused add rms norm return 2 outputs

* add builds

* add plot

---------

Co-authored-by: taehyun <[email protected]>
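
For reference, a minimal sketch of the call pattern after this change, mirroring the README update further down. The shapes, dtype, and the identity stand-in for `some_op` are illustrative only, not part of the commit:

```python
import torch
import activation  # the kernels package built by this repo

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(128, device="cuda", dtype=torch.float16)

# fused_add_rms_norm now returns both the normalized tensor and the
# add output y = x + residual, so y can feed a second residual branch.
hidden_state, y = activation.fused_add_rms_norm(x, residual, weight, eps=1e-6)
out = y + hidden_state  # stand-in for `out = y + some_op(hidden_state)` in the README
```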

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. README.md +4 -2
  2. activation/fused_add_rms_norm.cu +221 -3
  3. activation/fused_mul_poly_norm.cu +4 -4
  4. benchmarks/cases/add_rms.py +2 -1
  5. benchmarks/common/bench_framework.py +8 -2
  6. benchmarks/common/diff_engine.py +4 -1
  7. benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png +0 -0
  8. benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png +0 -0
  9. benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png +0 -0
  10. benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png +0 -0
  11. benchmarks/plots/h100/poly/plot_poly-bwd-perf.png +0 -0
  12. benchmarks/plots/h100/poly/plot_poly-fwd-perf.png +0 -0
  13. benchmarks/plots/h100/rms/plot_rms-bwd-perf.png +0 -0
  14. benchmarks/plots/h100/rms/plot_rms-fwd-perf.png +0 -0
  15. benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png +0 -0
  16. benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png +0 -0
  17. build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py +1 -1
  18. build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +3 -0
  19. build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py +3 -3
  20. build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py +1 -1
  21. build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py +2 -8
  22. build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py +1 -1
  23. build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +3 -0
  24. build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py +3 -3
  25. build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py +1 -1
  26. build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py +2 -8
  27. build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py +1 -1
  28. build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +3 -0
  29. build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py +3 -3
  30. build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py +1 -1
  31. build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py +2 -8
  32. build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py +1 -1
  33. build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +3 -0
  34. build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +3 -3
  35. build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py +1 -1
  36. build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py +2 -8
  37. build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py +1 -1
  38. build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +3 -0
  39. build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py +3 -3
  40. build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py +1 -1
  41. build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py +2 -8
  42. build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py +1 -1
  43. build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +3 -0
  44. build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py +3 -3
  45. build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py +1 -1
  46. build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py +2 -8
  47. build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py +1 -1
  48. build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so +3 -0
  49. build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py +3 -3
  50. build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py +1 -1
README.md CHANGED
@@ -18,13 +18,15 @@ Activation is a python package that contains custom CUDA-based activation kernel
 
   ```python
   y = x + residual
-  out = rms_norm(y, weight, eps)
+  hidden_state = rms_norm(y, weight, eps)
+  out = y + some_op(hidden_state)
   ```
 
   - Fused as:
 
   ```python
-  out = fused_add_rms_norm(x, residual, weight, eps)
+  hidden_state, y = fused_add_rms_norm(x, residual, weight, eps)
+  out = y + some_op(hidden_state)
   ```
 
   - **FusedMulPolyNorm**
activation/fused_add_rms_norm.cu CHANGED
@@ -117,9 +117,175 @@ fused_add_rms_norm_kernel(scalar_t *__restrict__ out, // [..., d]
   }
 }
 
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width > 0)> fused_add_rms_norm_backward_kernel(
+    scalar_t *__restrict__ input_grad,            // [..., d]
+    acc_t *__restrict__ temp_weight_grad,         // [..., d]
+    const scalar_t *__restrict__ output_grad,     // [..., d]
+    const scalar_t *__restrict__ add_output_grad, // [..., d]
+    const scalar_t *__restrict__ input,           // [..., d]
+    const scalar_t *__restrict__ weight,          // [d]
+    const float eps, const int d) {
+  using vec_t = type_vec_t<scalar_t, width>;
+  using dw_vec_t = type_vec_t<acc_t, width>;
+
+  const int64_t token_idx = blockIdx.x;
+  const int64_t vec_idx = threadIdx.x;
+
+  const int vec_d = d / width;
+  const int64_t vec_offset = token_idx * vec_d;
+
+  const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
+  const vec_t *__restrict__ output_grad_vec =
+      reinterpret_cast<const vec_t *>(output_grad);
+  const vec_t *__restrict__ weight_vec =
+      reinterpret_cast<const vec_t *>(weight);
+
+  acc_t d_sum = 0.0f;
+  acc_t sum_square = 0.0f;
+
+  for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + vidx];
+    vec_t dy_vec = output_grad_vec[vec_offset + vidx];
+    vec_t w_vec = weight_vec[vidx];
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x = x_vec.data[i];
+      acc_t dy = dy_vec.data[i];
+      acc_t w = w_vec.data[i];
+      d_sum += dy * x * w;
+      sum_square += x * x;
+    }
+  }
+
+  using BlockReduce = cub::BlockReduce<float2, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  struct SumOp {
+    __device__ float2 operator()(const float2 &a, const float2 &b) const {
+      return make_float2(a.x + b.x, a.y + b.y);
+    }
+  };
+  float2 thread_sums = make_float2(d_sum, sum_square);
+  float2 block_sums =
+      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+  d_sum = block_sums.x;
+  sum_square = block_sums.y;
+
+  __shared__ acc_t s_scale;
+  __shared__ acc_t s_dxx;
+
+  if (threadIdx.x == 0) {
+    acc_t scale = rsqrtf(sum_square / d + eps);
+    s_dxx = d_sum * scale * scale * scale / d;
+    s_scale = scale;
+  }
+  __syncthreads();
+  acc_t scale = s_scale;
+  acc_t dxx = s_dxx;
+  vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
+  dw_vec_t *__restrict__ temp_weight_grad_vec =
+      reinterpret_cast<dw_vec_t *>(temp_weight_grad);
+  const vec_t *__restrict__ add_output_grad_vec =
+      reinterpret_cast<const vec_t *>(add_output_grad);
+
+  for (int64_t vidx = vec_idx; vidx < vec_d; vidx += blockDim.x) {
+    vec_t x_vec = input_vec[vec_offset + vidx];
+    vec_t dy_vec = output_grad_vec[vec_offset + vidx];
+    vec_t da_vec = add_output_grad_vec[vec_offset + vidx];
+    vec_t w_vec = weight_vec[vidx];
+
+    vec_t in_grad_vec;
+    dw_vec_t tw_grad_vec;
+
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      acc_t x = x_vec.data[i];
+      acc_t dy = dy_vec.data[i];
+      acc_t w = w_vec.data[i];
+
+      if (input_grad) {
+        scalar_t da = da_vec.data[i];
+        scalar_t in_grad = scale * dy * w - dxx * x;
+        in_grad_vec.data[i] = in_grad + da;
+      }
+      tw_grad_vec.data[i] = dy * x * scale;
+    }
+
+    if (input_grad) {
+      input_grad_vec[vec_offset + vidx] = in_grad_vec;
+    }
+    temp_weight_grad_vec[vec_offset + vidx] = tw_grad_vec;
+  }
+}
+
+template <typename scalar_t, typename acc_t, int width>
+__global__ std::enable_if_t<(width == 0)> fused_add_rms_norm_backward_kernel(
+    scalar_t *__restrict__ input_grad,            // [..., d]
+    acc_t *__restrict__ temp_weight_grad,         // [..., d]
+    const scalar_t *__restrict__ output_grad,     // [..., d]
+    const scalar_t *__restrict__ add_output_grad, // [..., d]
+    const scalar_t *__restrict__ input,           // [..., d]
+    const scalar_t *__restrict__ weight,          // [d]
+    const float eps, const int d) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t vec_idx = threadIdx.x;
+  acc_t d_sum = 0.0f;
+  acc_t sum_square = 0.0f;
+
+  for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
+    acc_t x = input[token_idx * d + idx];
+    acc_t dy = output_grad[token_idx * d + idx];
+    acc_t w = weight[idx];
+    d_sum += dy * x * w;
+    sum_square += x * x;
+  }
+
+  using BlockReduce = cub::BlockReduce<float2, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStore;
+  struct SumOp {
+    __device__ float2 operator()(const float2 &a, const float2 &b) const {
+      return make_float2(a.x + b.x, a.y + b.y);
+    }
+  };
+  float2 thread_sums = make_float2(d_sum, sum_square);
+  float2 block_sums =
+      BlockReduce(reduceStore).Reduce(thread_sums, SumOp{}, blockDim.x);
+
+  d_sum = block_sums.x;
+  sum_square = block_sums.y;
+
+  __shared__ acc_t s_scale;
+  __shared__ acc_t s_dxx;
+
+  if (threadIdx.x == 0) {
+    acc_t scale = rsqrtf(sum_square / d + eps);
+    s_dxx = d_sum * scale * scale * scale / d;
+    s_scale = scale;
+  }
+  __syncthreads();
+
+  acc_t scale = s_scale;
+  acc_t dxx = s_dxx;
+
+  for (int64_t idx = vec_idx; idx < d; idx += blockDim.x) {
+    acc_t x = input[token_idx * d + idx];
+    acc_t dy = output_grad[token_idx * d + idx];
+    acc_t w = weight[idx];
+
+    if (input_grad) {
+      scalar_t da = add_output_grad[token_idx * d + idx];
+      scalar_t in_grad = scale * dy * w - dxx * x;
+      input_grad[token_idx * d + idx] = in_grad + da;
+    }
+    temp_weight_grad[token_idx * d + idx] = dy * x * scale;
+  }
+}
+
 } // namespace motif
 
-#define LAUNCH_RMS_NORM(width) \
+#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
   MOTIF_DISPATCH_FLOATING_TYPES( \
       input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
         motif::fused_add_rms_norm_kernel<scalar_t, float, width> \
@@ -150,8 +316,60 @@ void fused_add_rms_norm(torch::Tensor &out, // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   if (d % 8 == 0) {
-    LAUNCH_RMS_NORM(8);
+    LAUNCH_FUSED_ADD_RMS_NORM(8);
+  } else {
+    LAUNCH_FUSED_ADD_RMS_NORM(0);
+  }
+}
+
+#define LAUNCH_FUSED_ADD_RMS_NORM_BWD(width) \
+  MOTIF_DISPATCH_FLOATING_TYPES( \
+      input.scalar_type(), "fused_add_rms_norm_backward_kernel", [&] { \
+        motif::fused_add_rms_norm_backward_kernel<scalar_t, float, width> \
+            <<<grid, block, 0, stream>>>(input_grad.data_ptr<scalar_t>(), \
+                                         temp_weight_grad.data_ptr<float>(), \
+                                         output_grad.data_ptr<scalar_t>(), \
+                                         add_output_grad.data_ptr<scalar_t>(), \
+                                         input.data_ptr<scalar_t>(), \
+                                         weight.data_ptr<scalar_t>(), eps, d); \
+      });
+
+void fused_add_rms_norm_backward(
+    torch::Tensor &input_grad,            // [..., d]
+    torch::Tensor &weight_grad,           // [d]
+    const torch::Tensor &output_grad,     // [..., d]
+    const torch::Tensor &add_output_grad, // [..., d]
+    const torch::Tensor &input,           // [..., d]
+    const torch::Tensor &weight,          // [d]
+    double eps) {
+  AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
+  AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
+  AssertTensorShapeEqual(input, output_grad, "input", "add_output_grad");
+  AssertTensorNotNull(weight, "weight");
+  // TODO shape check
+  // weight_grad, input_grad can be nullable
+
+  int d = input.size(-1);
+  int64_t num_tokens = input.numel() / input.size(-1);
+  dim3 grid(num_tokens);
+  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+  dim3 block(std::min(d, max_block_size));
+
+  torch::Tensor temp_weight_grad =
+      torch::empty({num_tokens, d}, input.options().dtype(torch::kFloat));
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  if (d % 8 == 0) {
+    LAUNCH_FUSED_ADD_RMS_NORM_BWD(8);
   } else {
-    LAUNCH_RMS_NORM(0);
+    LAUNCH_FUSED_ADD_RMS_NORM_BWD(0);
+  }
+
+  if (weight_grad.defined()) {
+    torch::Tensor acc =
+        torch::empty_like(weight_grad, temp_weight_grad.options());
+    at::sum_out(acc, temp_weight_grad, {0});
+    weight_grad.copy_(acc);
   }
 }
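
As a cross-check on the math in the new backward kernel: for out = weight * y * rsqrt(mean(y^2) + eps) with y = x + residual, the kernel computes scale = rsqrt(mean(y^2) + eps), dxx = sum(dy * y * w) * scale^3 / d, writes input_grad = scale * dy * w - dxx * y plus the gradient arriving on the second output y, and stores per-token dy * y * scale that the host then sums over tokens for the weight gradient. A minimal PyTorch reference of that formula (illustrative function name, not the compiled op):

```python
import torch

def fused_add_rms_norm_backward_ref(output_grad, add_output_grad, add_output, weight, eps):
    """Unfused reference gradients for out = weight * y * rsqrt(mean(y^2) + eps), y = x + residual."""
    y = add_output.float()
    dy = output_grad.float()
    w = weight.float()
    d = y.shape[-1]
    scale = torch.rsqrt(y.pow(2).mean(dim=-1, keepdim=True) + eps)   # s
    dxx = (dy * y * w).sum(dim=-1, keepdim=True) * scale.pow(3) / d  # s^3 * sum(dy*y*w) / d
    # d(out)/dy plus the gradient flowing directly into the second output y
    input_grad = scale * dy * w - dxx * y + add_output_grad.float()
    # dL/dw is dy * y * s per token, reduced over all leading dimensions
    weight_grad = (dy * y * scale).reshape(-1, d).sum(dim=0)
    return input_grad.to(output_grad.dtype), weight_grad.to(weight.dtype)
```

Because y = x + residual, this single input_grad serves as the gradient for both x and residual, which is how the Python wrapper in rms_norm.py reuses one grad buffer for both inputs.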
activation/fused_mul_poly_norm.cu CHANGED
@@ -573,7 +573,7 @@ void fused_mul_poly_norm(torch::Tensor &out, // [..., d]
   }
 }
 
-#define LAUNCH_POLY_NORM_BACKWARD(width) \
+#define LAUNCH_FUSED_MUL_POLY_NORM_BACKWARD(width) \
   MOTIF_DISPATCH_FLOATING_TYPES( \
       input.scalar_type(), "fused_mul_poly_norm_backward_kernel", [&] { \
         motif::fused_mul_poly_norm_backward_kernel<scalar_t, float, width> \
@@ -620,11 +620,11 @@ void fused_mul_poly_norm_backward(torch::Tensor &input_grad, // [..., d]
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   if (d % 8 == 0 && input.element_size() == 2) {
-    LAUNCH_POLY_NORM_BACKWARD(8);
+    LAUNCH_FUSED_MUL_POLY_NORM_BACKWARD(8);
   } else if (d % 4 == 0 && input.element_size() == 4) {
-    LAUNCH_POLY_NORM_BACKWARD(4);
+    LAUNCH_FUSED_MUL_POLY_NORM_BACKWARD(4);
   } else {
-    LAUNCH_POLY_NORM_BACKWARD(0);
+    LAUNCH_FUSED_MUL_POLY_NORM_BACKWARD(0);
   }
 
   if (bias_grad.defined()) {
benchmarks/cases/add_rms.py CHANGED
@@ -12,7 +12,8 @@ class FusedAddRMSNorm(torch.nn.Module):
         self.eps = eps
 
     def forward(self, x, residual):
-        return activation.rms_norm((x + residual), self.weight, self.eps)
+        h = x + residual
+        return activation.rms_norm(h, self.weight, self.eps), h
 
 
 class AddRMS(DiffCase):
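
The naive benchmark module now mirrors the fused op's two-output contract, so the diff engine and benchmarks can treat both providers uniformly. A pure-PyTorch sketch of that contract, usable when the `activation` package is unavailable (the function name and explicit float32 accumulation are assumptions):

```python
import torch

def naive_add_rms_norm(x, residual, weight, eps=1e-6):
    # Returns (normed, add_output), matching the fused kernel's outputs.
    h = x + residual
    scale = torch.rsqrt(h.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = (h.float() * scale).to(h.dtype) * weight
    return normed, h
```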
benchmarks/common/bench_framework.py CHANGED
@@ -149,7 +149,10 @@ def make_bwd_benchmark_for_case(
         obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
         y = case.forward(obj, I)
         gin = list(case.grad_inputs(I)) + list(obj.parameters())
-        g = torch.randn_like(y)
+        if isinstance(y, torch.Tensor):
+            g = [torch.randn_like(y)]
+        else:
+            g = [torch.randn_like(r) for r in y]
         run = lambda: torch.autograd.grad(y,
                                           gin,
                                           g,
@@ -201,7 +204,10 @@ def make_bwd_benchmark_plot_for_case(
         obj = case.make_naive(I) if provider == "naive" else case.make_cuda(I)
         y = case.forward(obj, I)
         gin = list(case.grad_inputs(I)) + list(obj.parameters())
-        g = torch.randn_like(y)
+        if isinstance(y, torch.Tensor):
+            g = [torch.randn_like(y)]
+        else:
+            g = [torch.randn_like(r) for r in y]
         run = lambda: torch.autograd.grad(y,
                                           gin,
                                           g,
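
The backward benchmark now has to supply one upstream gradient per output, since `torch.autograd.grad` expects `grad_outputs` to line up element-for-element with `outputs`. A standalone toy illustrating the pattern (shapes and the stand-in ops are arbitrary):

```python
import torch

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, requires_grad=True)

outputs = (x * w, x + 1.0)                              # e.g. (normed, add_output)
grad_outputs = [torch.randn_like(t) for t in outputs]   # one gradient per output
grads = torch.autograd.grad(outputs, [x, w], grad_outputs)
```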
benchmarks/common/diff_engine.py CHANGED
@@ -68,7 +68,10 @@ def calculate_diff(
     torch.testing.assert_close(y_n, y_c, atol=atol, rtol=rtol)
     gin_n = list(case.grad_inputs(I_n)) + list(obj_n.parameters())
     gin_c = list(case.grad_inputs(I_c)) + list(obj_c.parameters())
-    g = _unit_grad_like(y_n).to(device)
+    if isinstance(y_n, torch.Tensor):
+        g = [_unit_grad_like(y_n).to(device)]
+    else:
+        g = [_unit_grad_like(r).to(device) for r in y_n]
     ng = torch.autograd.grad(y_n,
                              gin_n,
                              g,
benchmarks/plots/h100/add_rms/plot_add_rms-bwd-perf.png CHANGED
benchmarks/plots/h100/add_rms/plot_add_rms-fwd-perf.png CHANGED
benchmarks/plots/h100/mul_poly/plot_mul_poly-bwd-perf.png CHANGED
benchmarks/plots/h100/mul_poly/plot_mul_poly-fwd-perf.png CHANGED
benchmarks/plots/h100/poly/plot_poly-bwd-perf.png CHANGED
benchmarks/plots/h100/poly/plot_poly-fwd-perf.png CHANGED
benchmarks/plots/h100/rms/plot_rms-bwd-perf.png CHANGED
benchmarks/plots/h100/rms/plot_rms-fwd-perf.png CHANGED
benchmarks/plots/mi250/add_rms/plot_add_rms-bwd-perf.png CHANGED
benchmarks/plots/mi250/add_rms/plot_add_rms-fwd-perf.png CHANGED
build/torch27-cxx11-cu118-x86_64-linux/activation/__init__.py CHANGED
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
     weight: torch.Tensor,
     eps: float = 1e-6,
 ) -> None:
-    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
 
 
 __all__ = [
build/torch27-cxx11-cu118-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ec9ea7edc8b27f7983e20d615ab470cef6b82975afc214becfddfd05a867a839
+size 8600336
build/torch27-cxx11-cu118-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_20250907180255
-ops = torch.ops._activation_20250907180255
+from . import _activation_e5e2eeb_dirty
+ops = torch.ops._activation_e5e2eeb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_20250907180255::{op_name}"
+    return f"_activation_e5e2eeb_dirty::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/activation/layers.py CHANGED
@@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module):
         residual: torch.Tensor,
     ):
         return FusedAddRMSNormFunction.apply(x, residual, self.weight,
-                                             self.eps)[0]
+                                             self.eps)
 
     def reset_parameters(self) -> None:
         """
build/torch27-cxx11-cu118-x86_64-linux/activation/rms_norm.py CHANGED
@@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     def setup_context(ctx, inputs, outputs):
         _, _, weight, eps = inputs
         _, add_output = outputs
-        ctx.mark_non_differentiable(add_output)
-        ctx.set_materialize_grads(False)
         ctx.save_for_backward(weight, add_output)
         ctx.eps = eps
 
-    # This function only needs one gradient
     @staticmethod
-    def backward(ctx, output_grad, _):
+    def backward(ctx, output_grad, add_output_grad):
         weight, add_output = ctx.saved_tensors
         eps = ctx.eps
 
-        if output_grad is None:
-            output_grad = torch.zeros_like(add_output)
-
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
@@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         weight_grad = torch.empty_like(
             weight) if ctx.needs_input_grad[2] else None
 
-        ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+        ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
                               weight, eps)
        input_grad = grad if need_in else None
        residual_grad = grad if need_res else None
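
Since the fused op now exposes both tensors as differentiable outputs, `backward` receives one gradient per output. A toy `torch.autograd.Function` with the same two-output shape, purely illustrative and not the shipped `FusedAddRMSNormFunction`:

```python
import torch

class TwoOutputFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, residual):
        y = x + residual
        out = 2.0 * y            # stand-in for the normalization
        return out, y

    @staticmethod
    def backward(ctx, output_grad, add_output_grad):
        # One incoming gradient per forward output; both fold into the
        # gradients of x and residual because y = x + residual.
        grad = 2.0 * output_grad + add_output_grad
        return grad, grad

x = torch.randn(3, requires_grad=True)
r = torch.randn(3, requires_grad=True)
out, y = TwoOutputFn.apply(x, r)
(out.sum() + y.sum()).backward()  # x.grad == r.grad == 3.0 everywhere
```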
build/torch27-cxx11-cu126-x86_64-linux/activation/__init__.py CHANGED
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
     weight: torch.Tensor,
     eps: float = 1e-6,
 ) -> None:
-    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
 
 
 __all__ = [
build/torch27-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d3511410cdc288d2fafc500223ed2e625e360f50fa341809cf892fb2c822924
+size 8779000
build/torch27-cxx11-cu126-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_20250907180255
-ops = torch.ops._activation_20250907180255
+from . import _activation_e5e2eeb_dirty
+ops = torch.ops._activation_e5e2eeb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_20250907180255::{op_name}"
+    return f"_activation_e5e2eeb_dirty::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/activation/layers.py CHANGED
@@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module):
         residual: torch.Tensor,
     ):
         return FusedAddRMSNormFunction.apply(x, residual, self.weight,
-                                             self.eps)[0]
+                                             self.eps)
 
     def reset_parameters(self) -> None:
         """
build/torch27-cxx11-cu126-x86_64-linux/activation/rms_norm.py CHANGED
@@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     def setup_context(ctx, inputs, outputs):
         _, _, weight, eps = inputs
         _, add_output = outputs
-        ctx.mark_non_differentiable(add_output)
-        ctx.set_materialize_grads(False)
         ctx.save_for_backward(weight, add_output)
         ctx.eps = eps
 
-    # This function only needs one gradient
     @staticmethod
-    def backward(ctx, output_grad, _):
+    def backward(ctx, output_grad, add_output_grad):
         weight, add_output = ctx.saved_tensors
         eps = ctx.eps
 
-        if output_grad is None:
-            output_grad = torch.zeros_like(add_output)
-
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
@@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         weight_grad = torch.empty_like(
             weight) if ctx.needs_input_grad[2] else None
 
-        ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+        ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
                               weight, eps)
        input_grad = grad if need_in else None
        residual_grad = grad if need_res else None
build/torch27-cxx11-cu128-x86_64-linux/activation/__init__.py CHANGED
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
     weight: torch.Tensor,
     eps: float = 1e-6,
 ) -> None:
-    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
 
 
 __all__ = [
build/torch27-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25efc9c32e4bd6609a8326025aad861cbf79b544893755fe44519c9df7224c40
+size 13818872
build/torch27-cxx11-cu128-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_20250907180255
-ops = torch.ops._activation_20250907180255
+from . import _activation_e5e2eeb_dirty
+ops = torch.ops._activation_e5e2eeb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_20250907180255::{op_name}"
+    return f"_activation_e5e2eeb_dirty::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/activation/layers.py CHANGED
@@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module):
         residual: torch.Tensor,
     ):
         return FusedAddRMSNormFunction.apply(x, residual, self.weight,
-                                             self.eps)[0]
+                                             self.eps)
 
     def reset_parameters(self) -> None:
         """
build/torch27-cxx11-cu128-x86_64-linux/activation/rms_norm.py CHANGED
@@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     def setup_context(ctx, inputs, outputs):
         _, _, weight, eps = inputs
         _, add_output = outputs
-        ctx.mark_non_differentiable(add_output)
-        ctx.set_materialize_grads(False)
         ctx.save_for_backward(weight, add_output)
         ctx.eps = eps
 
-    # This function only needs one gradient
     @staticmethod
-    def backward(ctx, output_grad, _):
+    def backward(ctx, output_grad, add_output_grad):
         weight, add_output = ctx.saved_tensors
         eps = ctx.eps
 
-        if output_grad is None:
-            output_grad = torch.zeros_like(add_output)
-
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
@@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         weight_grad = torch.empty_like(
             weight) if ctx.needs_input_grad[2] else None
 
-        ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+        ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
                               weight, eps)
        input_grad = grad if need_in else None
        residual_grad = grad if need_res else None
build/torch27-cxx11-rocm63-x86_64-linux/activation/__init__.py CHANGED
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
     weight: torch.Tensor,
     eps: float = 1e-6,
 ) -> None:
-    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
 
 
 __all__ = [
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c80d05690547f2842d416ebb85c9f830370373bc7e6c54ba08eec61b3690280f
+size 4386744
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_20250907180255
-ops = torch.ops._activation_20250907180255
+from . import _activation_e5e2eeb_dirty
+ops = torch.ops._activation_e5e2eeb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_20250907180255::{op_name}"
+    return f"_activation_e5e2eeb_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/activation/layers.py CHANGED
@@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module):
         residual: torch.Tensor,
     ):
         return FusedAddRMSNormFunction.apply(x, residual, self.weight,
-                                             self.eps)[0]
+                                             self.eps)
 
     def reset_parameters(self) -> None:
         """
build/torch27-cxx11-rocm63-x86_64-linux/activation/rms_norm.py CHANGED
@@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     def setup_context(ctx, inputs, outputs):
         _, _, weight, eps = inputs
         _, add_output = outputs
-        ctx.mark_non_differentiable(add_output)
-        ctx.set_materialize_grads(False)
         ctx.save_for_backward(weight, add_output)
         ctx.eps = eps
 
-    # This function only needs one gradient
     @staticmethod
-    def backward(ctx, output_grad, _):
+    def backward(ctx, output_grad, add_output_grad):
         weight, add_output = ctx.saved_tensors
         eps = ctx.eps
 
-        if output_grad is None:
-            output_grad = torch.zeros_like(add_output)
-
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
@@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         weight_grad = torch.empty_like(
             weight) if ctx.needs_input_grad[2] else None
 
-        ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+        ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
                               weight, eps)
        input_grad = grad if need_in else None
        residual_grad = grad if need_res else None
build/torch28-cxx11-cu126-x86_64-linux/activation/__init__.py CHANGED
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
     weight: torch.Tensor,
     eps: float = 1e-6,
 ) -> None:
-    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
 
 
 __all__ = [
build/torch28-cxx11-cu126-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:440f5c17a7ddaf73c506bbc84fd1405e2e188b8ceaf4977910608be6b91e89bf
+size 8730200
build/torch28-cxx11-cu126-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_20250907180255
-ops = torch.ops._activation_20250907180255
+from . import _activation_e5e2eeb_dirty
+ops = torch.ops._activation_e5e2eeb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_20250907180255::{op_name}"
+    return f"_activation_e5e2eeb_dirty::{op_name}"
build/torch28-cxx11-cu126-x86_64-linux/activation/layers.py CHANGED
@@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module):
         residual: torch.Tensor,
     ):
         return FusedAddRMSNormFunction.apply(x, residual, self.weight,
-                                             self.eps)[0]
+                                             self.eps)
 
     def reset_parameters(self) -> None:
         """
build/torch28-cxx11-cu126-x86_64-linux/activation/rms_norm.py CHANGED
@@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     def setup_context(ctx, inputs, outputs):
         _, _, weight, eps = inputs
         _, add_output = outputs
-        ctx.mark_non_differentiable(add_output)
-        ctx.set_materialize_grads(False)
         ctx.save_for_backward(weight, add_output)
         ctx.eps = eps
 
-    # This function only needs one gradient
     @staticmethod
-    def backward(ctx, output_grad, _):
+    def backward(ctx, output_grad, add_output_grad):
         weight, add_output = ctx.saved_tensors
         eps = ctx.eps
 
-        if output_grad is None:
-            output_grad = torch.zeros_like(add_output)
-
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
@@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         weight_grad = torch.empty_like(
             weight) if ctx.needs_input_grad[2] else None
 
-        ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+        ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
                               weight, eps)
        input_grad = grad if need_in else None
        residual_grad = grad if need_res else None
build/torch28-cxx11-cu128-x86_64-linux/activation/__init__.py CHANGED
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
     weight: torch.Tensor,
     eps: float = 1e-6,
 ) -> None:
-    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
 
 
 __all__ = [
build/torch28-cxx11-cu128-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1dfb6d468f9cef0239d4ea47f0a247fa721befc5b8db86e1cddfc25f1814b67a
+size 13770064
build/torch28-cxx11-cu128-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_20250907180255
-ops = torch.ops._activation_20250907180255
+from . import _activation_e5e2eeb_dirty
+ops = torch.ops._activation_e5e2eeb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_20250907180255::{op_name}"
+    return f"_activation_e5e2eeb_dirty::{op_name}"
build/torch28-cxx11-cu128-x86_64-linux/activation/layers.py CHANGED
@@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module):
         residual: torch.Tensor,
     ):
         return FusedAddRMSNormFunction.apply(x, residual, self.weight,
-                                             self.eps)[0]
+                                             self.eps)
 
     def reset_parameters(self) -> None:
         """
build/torch28-cxx11-cu128-x86_64-linux/activation/rms_norm.py CHANGED
@@ -54,20 +54,14 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
     def setup_context(ctx, inputs, outputs):
         _, _, weight, eps = inputs
         _, add_output = outputs
-        ctx.mark_non_differentiable(add_output)
-        ctx.set_materialize_grads(False)
         ctx.save_for_backward(weight, add_output)
         ctx.eps = eps
 
-    # This function only needs one gradient
     @staticmethod
-    def backward(ctx, output_grad, _):
+    def backward(ctx, output_grad, add_output_grad):
         weight, add_output = ctx.saved_tensors
         eps = ctx.eps
 
-        if output_grad is None:
-            output_grad = torch.zeros_like(add_output)
-
         need_in = ctx.needs_input_grad[0]
         need_res = ctx.needs_input_grad[1]
 
@@ -76,7 +70,7 @@ class FusedAddRMSNormFunction(torch.autograd.Function):
         weight_grad = torch.empty_like(
             weight) if ctx.needs_input_grad[2] else None
 
-        ops.rms_norm_backward(grad, weight_grad, output_grad, add_output,
+        ops.fused_add_rms_norm_backward(grad, weight_grad, output_grad, add_output_grad, add_output,
                               weight, eps)
        input_grad = grad if need_in else None
        residual_grad = grad if need_res else None
build/torch28-cxx11-cu129-x86_64-linux/activation/__init__.py CHANGED
@@ -39,7 +39,7 @@ def fused_add_rms_norm(
     weight: torch.Tensor,
     eps: float = 1e-6,
 ) -> None:
-    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)[0]
+    return FusedAddRMSNormFunction.apply(x, residual, weight, eps)
 
 
 __all__ = [
build/torch28-cxx11-cu129-x86_64-linux/activation/_activation_e5e2eeb_dirty.abi3.so ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0815a50e61497b357b2b90fc28602b3f53a25da1161edd2cb0b0fbebc7c62bf6
+size 13757248
build/torch28-cxx11-cu129-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_20250907180255
-ops = torch.ops._activation_20250907180255
+from . import _activation_e5e2eeb_dirty
+ops = torch.ops._activation_e5e2eeb_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_20250907180255::{op_name}"
+    return f"_activation_e5e2eeb_dirty::{op_name}"
build/torch28-cxx11-cu129-x86_64-linux/activation/layers.py CHANGED
@@ -85,7 +85,7 @@ class FusedAddRMSNorm(nn.Module):
         residual: torch.Tensor,
     ):
         return FusedAddRMSNormFunction.apply(x, residual, self.weight,
-                                             self.eps)[0]
+                                             self.eps)
 
     def reset_parameters(self) -> None:
         """
  """