ca1207 committed on
Commit f8a7e6f · 1 Parent(s): 18ec195

optimize poly norm kernel

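For reference, both the old and the optimized kernels evaluate the same per-token PolyNorm expression; reconstructed here from the kernel source (the notation is ours, not part of the commit):

\[
\mathrm{out}_i = w_2\,\frac{x_i}{\sqrt{\overline{x^2}+\varepsilon}}
              + w_1\,\frac{x_i^2}{\sqrt{\overline{x^4}+\varepsilon}}
              + w_0\,\frac{x_i^3}{\sqrt{\overline{x^6}+\varepsilon}} + b,
\qquad
\overline{x^k} = \frac{1}{d}\sum_{j=1}^{d} x_j^{\,k}
\]

The optimization keeps this math but replaces the pow()-based loops and the hand-rolled _block_reduce_sum with vectorized loads, cub/hipcub block reductions, and precomputed w * rsqrt(mean + eps) factors.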
activation/cuda_compat.h CHANGED
@@ -1,18 +1,20 @@
  #pragma once

- #ifdef USE_ROCM
- #include <hip/hip_runtime.h>
  #endif

  #ifndef USE_ROCM
- #define WARP_SIZE 32
  #else
- #define WARP_SIZE warpSize
  #endif

  #ifndef USE_ROCM
- #define VLLM_LDG(arg) __ldg(arg)
  #else
- #define VLLM_LDG(arg) *(arg)
  #endif
-

  #pragma once

+ #ifndef USE_ROCM
+ #include <cub/cub.cuh>
+ #else
+ #include <hip/hip_runtime.h>
+ #include <hipcub/hipcub.hpp>
  #endif

  #ifndef USE_ROCM
+ #define WARP_SIZE 32
  #else
+ #define WARP_SIZE warpSize
  #endif

  #ifndef USE_ROCM
+ #define VLLM_LDG(arg) __ldg(arg)
  #else
+ #define VLLM_LDG(arg) *(arg)
  #endif
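The header now pulls in cub on CUDA builds and hipcub on ROCm builds because the rewritten poly_norm kernels below use cub-style block reductions instead of the previous _block_reduce_sum helper. A minimal sketch of that reduction pattern, with an illustrative kernel name and block size that are not part of the commit:

// Illustrative only: block-wide sum in the style used by the new
// poly_norm kernels (cub::BlockReduce + shared TempStorage).
template <int BLOCK_THREADS>
__global__ void block_sum_example(const float *in, float *out) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage tmp;

  float v = in[blockIdx.x * BLOCK_THREADS + threadIdx.x];
  float total = BlockReduce(tmp).Sum(v);  // result is valid in thread 0 only

  if (threadIdx.x == 0) {
    out[blockIdx.x] = total;
  }
}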
activation/poly_norm.cu CHANGED
@@ -1,246 +1,555 @@
- #include <ATen/cuda/CUDAContext.h>
  #include <ATen/Functions.h>
- #include <torch/all.h>
  #include <c10/cuda/CUDAGuard.h>

  #include <cmath>

- #include "cuda_compat.h"
- #include "dispatch_utils.h"
  #include "assert_utils.h"
  #include "atomic_utils.h"
  #include "block_reduce.h"

  namespace motif {

- template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
- __global__ void poly_norm_kernel(
-     scalar_t* __restrict__ out,          // [..., d]
-     const scalar_t* __restrict__ input,  // [..., d]
-     const scalar_t* __restrict__ weight, // [3]
-     const scalar_t* __restrict__ bias,   // [1]
-     const float eps,
-     const int d
- ) {
    const int64_t token_idx = blockIdx.x;

-   acc_t sum = 0.0f;
-   acc_t sum_square = 0.0f;
-   acc_t sum_cube = 0.0f;

    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-     acc_t x = input[token_idx * d + idx];
-     sum += pow(x, 2.0f);
-     sum_square += pow(x, 4.0f);
-     sum_cube += pow(x, 6.0f);
    }

-   __shared__ acc_t shared[BLOCK_SIZE];

-   acc_t mean = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum, d) / d;
-   acc_t mean_square = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
-   acc_t mean_cube = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_cube, d) / d;

-   acc_t w0 = weight[0];
-   acc_t w1 = weight[1];
-   acc_t w2 = weight[2];
-   acc_t b = bias[0];

-   acc_t divisor = sqrt(mean + eps);
-   acc_t divisor_square = sqrt(mean_square + eps);
-   acc_t divisor_cube = sqrt(mean_cube + eps);

    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-     acc_t x = input[token_idx * d + idx];
-     acc_t x_square = pow(x, 2.0f);
-     acc_t x_cube = pow(x, 3.0f);
-     out[token_idx * d + idx] = w2 * x / divisor +
-         w1 * x_square / divisor_square +
-         w0 * x_cube / divisor_cube + b;
    }
  }

- template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
- __global__ void poly_norm_backward_kernel(
-     scalar_t* __restrict__ input_grad,        // [..., d]
-     acc_t* __restrict__ temp_weight_grad,     // [..., 3]
-     const scalar_t* __restrict__ output_grad, // [..., d]
-     const scalar_t* __restrict__ input,       // [..., d]
-     const scalar_t* __restrict__ weight,      // [3]
-     const float eps,
-     const int d
- ) {
-   const int64_t token_idx = blockIdx.x;

    acc_t w0 = weight[0];
    acc_t w1 = weight[1];
    acc_t w2 = weight[2];

-   acc_t sum_2 = 0.0f;
-   acc_t sum_4 = 0.0f;
-   acc_t sum_6 = 0.0f;

-   acc_t sum_dx_1 = 0.0f;
-   acc_t sum_dx_2 = 0.0f;
-   acc_t sum_dx_3 = 0.0f;

    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
      acc_t dy = output_grad[token_idx * d + idx];

-     acc_t x_1 = input[token_idx * d + idx];
-     acc_t x_2 = x_1 * x_1;
-     acc_t x_3 = x_2 * x_1;
-     acc_t x_4 = x_2 * x_2;
-     acc_t x_6 = x_3 * x_3;

-     sum_2 += x_2;
-     sum_4 += x_4;
-     sum_6 += x_6;

-     sum_dx_1 += dy * x_1;
-     sum_dx_2 += dy * x_2;
-     sum_dx_3 += dy * x_3;
    }

-   __shared__ acc_t shared[BLOCK_SIZE];

-   acc_t mean_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_2, d) / d + eps;
-   acc_t mean_4 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_4, d) / d + eps;
-   acc_t mean_6 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_6, d) / d + eps;

-   sum_dx_1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_1, d);
-   sum_dx_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_2, d);
-   sum_dx_3 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_3, d);

-   acc_t _mean_2 = powf(mean_2, -1.5);
-   acc_t _mean_4 = powf(mean_4, -1.5);
-   acc_t _mean_6 = powf(mean_6, -1.5);

-   acc_t sq_mean_2 = sqrtf(mean_2);
-   acc_t sq_mean_4 = sqrtf(mean_4);
-   acc_t sq_mean_6 = sqrtf(mean_6);

    acc_t sum_dw0 = 0;
    acc_t sum_dw1 = 0;
    acc_t sum_dw2 = 0;

    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
      acc_t dy = output_grad[token_idx * d + idx];
-     acc_t x_1 = input[token_idx * d + idx];
-     acc_t x_2 = x_1 * x_1;
-     acc_t x_3 = x_2 * x_1;
-
-     acc_t dx_3 =
-         _mean_6 * 3 * x_2 * (dy * mean_6 - x_3 * sum_dx_3 / d) * w0;
-     acc_t dx_2 =
-         _mean_4 * 2 * x_1 * (dy * mean_4 - x_2 * sum_dx_2 / d) * w1;
-     acc_t dx_1 =
-         _mean_2 * (dy * mean_2 - x_1 * sum_dx_1 / d) * w2;

      if (input_grad) {
-       input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3;
      }

-     sum_dw0 += dy * (x_3 / sq_mean_6);
-     sum_dw1 += dy * (x_2 / sq_mean_4);
-     sum_dw2 += dy * (x_1 / sq_mean_2);
    }

-   if (temp_weight_grad) {
-     sum_dw0 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw0, d);
-     sum_dw1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw1, d);
-     sum_dw2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw2, d);
-
-     if (threadIdx.x == 0) {
-       temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
-       temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
-       temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
-     }
    }
  }

- } // namespace motif
-
-
- void poly_norm(torch::Tensor& out,          // [..., d]
-     const torch::Tensor& input,             // [..., d]
-     const torch::Tensor& weight,            // [3]
-     const torch::Tensor& bias,              // [1]
-     double eps)
- {
    AssertTensorShapeEqual(input, out, "input", "out");
    AssertTensorNotNull(weight, "weight");
    AssertTensorNotNull(bias, "bias");
    // TODO shape check

-   constexpr int BLOCK_SIZE = 256;
-
    int d = input.size(-1);
-   int64_t num_tokens = input.numel() / input.size(-1);
    dim3 grid(num_tokens);
-   dim3 block(BLOCK_SIZE);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-   MOTIF_DISPATCH_FLOATING_TYPES(
-       input.scalar_type(), "poly_norm_kernel", [&] {
-         motif::poly_norm_kernel<scalar_t, float, BLOCK_SIZE>
-             <<<grid, block, 0, stream>>>(
-                 out.data_ptr<scalar_t>(),
-                 input.data_ptr<scalar_t>(),
-                 weight.data_ptr<scalar_t>(),
-                 bias.data_ptr<scalar_t>(), eps, d);
-       }
-   );
  }

- void poly_norm_backward(
-     torch::Tensor& input_grad,        // [..., d]
-     torch::Tensor& weight_grad,       // [..., d]
-     torch::Tensor& bias_grad,         // [..., d]
-     const torch::Tensor& output_grad, // [3]
-     const torch::Tensor& input,       // [3]
-     const torch::Tensor& weight,      // [3]
-     double eps) {
    AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
    AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
    AssertTensorNotNull(weight, "weight");
    // TODO shape check
    // weight_grad, bias_grad and input_grad can be nullable

-   constexpr int BLOCK_SIZE = 256;
-
    int d = input.size(-1);
-   int64_t num_tokens = input.numel() / input.size(-1);
    dim3 grid(num_tokens);
-   dim3 block(BLOCK_SIZE);

    torch::Tensor temp_weight_grad =
-       torch::empty({num_tokens, 3},
-           input.options().dtype(torch::kFloat));
-
-   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-   MOTIF_DISPATCH_FLOATING_TYPES(
-       input.scalar_type(), "poly_norm_backward_kernel", [&] {
-         motif::poly_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
-             <<<grid, block, 0, stream>>>(
-                 input_grad.data_ptr<scalar_t>(),
-                 temp_weight_grad.data_ptr<float>(),
-                 output_grad.data_ptr<scalar_t>(),
-                 input.data_ptr<scalar_t>(),
-                 weight.data_ptr<scalar_t>(),
-                 eps, d);
-       }
-   );

    if (bias_grad.defined()) {
-     at::sum_out(bias_grad, output_grad);
-     bias_grad.resize_({1});
    }

    if (weight_grad.defined()) {
-     at::sum_out(weight_grad, temp_weight_grad, {0});
    }
  }
 
  #include <ATen/Functions.h>
+ #include <ATen/cuda/CUDAContext.h>
  #include <c10/cuda/CUDAGuard.h>
+ #include <torch/all.h>

  #include <cmath>

  #include "assert_utils.h"
  #include "atomic_utils.h"
  #include "block_reduce.h"
+ #include "cuda_compat.h"
+ #include "dispatch_utils.h"

  namespace motif {

+ template <typename type, int N> struct alignas(sizeof(type) * N) type_vec_t {
+   type data[N];
+ };
+
+ template <typename scalar_t, typename acc_t, int width>
+ __global__ std::enable_if_t<(width > 0)>
+ poly_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
+                  const scalar_t *__restrict__ input,  // [..., d]
+                  const scalar_t *__restrict__ weight, // [3]
+                  const scalar_t *__restrict__ bias,   // [1]
+                  const float eps, const int d) {
+   using vec_t = type_vec_t<scalar_t, width>;
+
+   const int vec_d = d / width;
+   const int64_t vec_offset = blockIdx.x * vec_d;
+   const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
+
+   acc_t sum2 = 0.0f;
+   acc_t sum4 = 0.0f;
+   acc_t sum6 = 0.0f;
+
+   for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+     vec_t x_vec = input_vec[vec_offset + idx];
+
+ #pragma unroll
+     for (int i = 0; i < width; ++i) {
+       acc_t x1 = static_cast<acc_t>(x_vec.data[i]);
+       acc_t x2 = x1 * x1;
+       acc_t x4 = x2 * x2;
+       acc_t x6 = x4 * x2;
+
+       sum2 += x2;
+       sum4 += x4;
+       sum6 += x6;
+     }
+   }
+
+   using BlockReduce = cub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;
+
+   sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
+   __syncthreads();
+   sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
+   __syncthreads();
+   sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+
+   __shared__ acc_t s_bias;
+
+   __shared__ acc_t s_w2_inv_std1;
+   __shared__ acc_t s_w1_inv_std2;
+   __shared__ acc_t s_w0_inv_std3;
+
+   if (threadIdx.x == 0) {
+     acc_t w0 = weight[0];
+     acc_t w1 = weight[1];
+     acc_t w2 = weight[2];
+     s_bias = bias[0];
+
+     s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2;
+     s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1;
+     s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0;
+   }
+   __syncthreads();
+
+   acc_t w2_inv_std1 = s_w2_inv_std1;
+   acc_t w1_inv_std2 = s_w1_inv_std2;
+   acc_t w0_inv_std3 = s_w0_inv_std3;
+   acc_t bias_reg = s_bias;
+
+   vec_t *__restrict__ output_vec = reinterpret_cast<vec_t *>(out);
+
+   for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+     vec_t x_vec = input_vec[vec_offset + idx];
+     vec_t y_vec;
+
+ #pragma unroll
+     for (int i = 0; i < width; ++i) {
+       acc_t x1 = static_cast<acc_t>(x_vec.data[i]);
+       acc_t x2 = x1 * x1;
+       acc_t x3 = x2 * x1;
+
+       acc_t y =
+           x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
+
+       y_vec.data[i] = static_cast<scalar_t>(y);
+     }
+     output_vec[vec_offset + idx] = y_vec;
+   }
+ }
+
+ template <typename scalar_t, typename acc_t, int width>
+ __global__ std::enable_if_t<(width == 0)>
+ poly_norm_kernel(scalar_t *__restrict__ out,          // [..., d]
+                  const scalar_t *__restrict__ input,  // [..., d]
+                  const scalar_t *__restrict__ weight, // [3]
+                  const scalar_t *__restrict__ bias,   // [1]
+                  const float eps, const int d) {
    const int64_t token_idx = blockIdx.x;

+   acc_t sum2 = 0.0f;
+   acc_t sum4 = 0.0f;
+   acc_t sum6 = 0.0f;

    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t x1 = input[token_idx * d + idx];
+     acc_t x2 = x1 * x1;
+     acc_t x4 = x2 * x2;
+     acc_t x6 = x4 * x2;
+
+     sum2 += x2;
+     sum4 += x4;
+     sum6 += x6;
    }

+   using BlockReduce = cub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;

+   sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
+   __syncthreads();
+   sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
+   __syncthreads();
+   sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);

+   __shared__ acc_t s_bias;
+
+   __shared__ acc_t s_w2_inv_std1;
+   __shared__ acc_t s_w1_inv_std2;
+   __shared__ acc_t s_w0_inv_std3;
+
+   if (threadIdx.x == 0) {
+     acc_t w0 = weight[0];
+     acc_t w1 = weight[1];
+     acc_t w2 = weight[2];
+     s_bias = bias[0];
+
+     s_w2_inv_std1 = rsqrtf(sum2 / d + eps) * w2;
+     s_w1_inv_std2 = rsqrtf(sum4 / d + eps) * w1;
+     s_w0_inv_std3 = rsqrtf(sum6 / d + eps) * w0;
+   }
+   __syncthreads();

+   acc_t w2_inv_std1 = s_w2_inv_std1;
+   acc_t w1_inv_std2 = s_w1_inv_std2;
+   acc_t w0_inv_std3 = s_w0_inv_std3;
+   acc_t bias_reg = s_bias;

    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t x1 = input[token_idx * d + idx];
+     acc_t x2 = x1 * x1;
+     acc_t x3 = x2 * x1;
+     out[token_idx * d + idx] =
+         x1 * w2_inv_std1 + x2 * w1_inv_std2 + x3 * w0_inv_std3 + bias_reg;
    }
  }

+ template <typename scalar_t, typename acc_t, int width>
+ __global__ std::enable_if_t<(width > 0)>
+ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad,        // [..., d]
+                           acc_t *__restrict__ temp_weight_grad,     // [..., 3]
+                           acc_t *__restrict__ temp_bias_grad,       // [..., 1]
+                           const scalar_t *__restrict__ output_grad, // [..., d]
+                           const scalar_t *__restrict__ input,       // [..., d]
+                           const scalar_t *__restrict__ weight,      // [3]
+                           const float eps, const int d) {
+   using vec_t = type_vec_t<scalar_t, width>;
+
+   const int vec_d = d / width;
+   const int64_t vec_offset = blockIdx.x * vec_d;
+   const vec_t *__restrict__ input_vec = reinterpret_cast<const vec_t *>(input);
+   const vec_t *__restrict__ output_grad_vec =
+       reinterpret_cast<const vec_t *>(output_grad);
+
+   acc_t sum2 = 0.0f;
+   acc_t sum4 = 0.0f;
+   acc_t sum6 = 0.0f;
+
+   acc_t sum_dx1 = 0.0f;
+   acc_t sum_dx2 = 0.0f;
+   acc_t sum_dx3 = 0.0f;
+
+   for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+     vec_t x_vec = input_vec[vec_offset + idx];
+     vec_t dy_vec = output_grad_vec[vec_offset + idx];
+
+ #pragma unroll
+     for (int i = 0; i < width; ++i) {
+       acc_t x1 = static_cast<acc_t>(x_vec.data[i]);
+       acc_t x2 = x1 * x1;
+       acc_t x3 = x2 * x1;
+       acc_t x4 = x2 * x2;
+       acc_t x6 = x3 * x3;
+
+       sum2 += x2;
+       sum4 += x4;
+       sum6 += x6;
+
+       acc_t dy = static_cast<acc_t>(dy_vec.data[i]);
+
+       sum_dx1 += dy * x1;
+       sum_dx2 += dy * x2;
+       sum_dx3 += dy * x3;
+     }
+   }
+
+   using BlockReduce = cub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;
+
+   __syncthreads();
+   sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
+   __syncthreads();
+   sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
+   __syncthreads();
+   sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+
+   __syncthreads();
+   sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x);
+   __syncthreads();
+   sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x);
+   __syncthreads();
+   sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x);
+
+   __shared__ acc_t s_mean2;
+   __shared__ acc_t s_mean4;
+   __shared__ acc_t s_mean6;
+   __shared__ acc_t s_sdx1;
+   __shared__ acc_t s_sdx2;
+   __shared__ acc_t s_sdx3;
+
+   const acc_t inv_d = acc_t(1) / d;
+
+   if (threadIdx.x == 0) {
+     s_mean2 = sum2 * inv_d + eps;
+     s_mean4 = sum4 * inv_d + eps;
+     s_mean6 = sum6 * inv_d + eps;
+
+     s_sdx1 = sum_dx1 * inv_d;
+     s_sdx2 = sum_dx2 * inv_d;
+     s_sdx3 = sum_dx3 * inv_d;
+   }
+   __syncthreads();

    acc_t w0 = weight[0];
    acc_t w1 = weight[1];
    acc_t w2 = weight[2];

+   acc_t mean2 = s_mean2;
+   acc_t mean4 = s_mean4;
+   acc_t mean6 = s_mean6;
+   acc_t sdx1 = s_sdx1;
+   acc_t sdx2 = s_sdx2;
+   acc_t sdx3 = s_sdx3;
+
+   acc_t inv_std1 = rsqrtf(mean2);
+   acc_t inv_std2 = rsqrtf(mean4);
+   acc_t inv_std3 = rsqrtf(mean6);
+
+   // inv_std / mean == powf(mean, -1.5)
+   acc_t c1 = w2 * inv_std1 / mean2;
+   acc_t c2 = acc_t(2) * w1 * inv_std2 / mean4;
+   acc_t c3 = acc_t(3) * w0 * inv_std3 / mean6;
+
+   acc_t sum_dy = 0;
+   acc_t sum_dw0 = 0;
+   acc_t sum_dw1 = 0;
+   acc_t sum_dw2 = 0;
+
+   vec_t *__restrict__ input_grad_vec = reinterpret_cast<vec_t *>(input_grad);
+
+   for (int64_t idx = threadIdx.x; idx < vec_d; idx += blockDim.x) {
+     vec_t x_vec = input_vec[vec_offset + idx];
+     vec_t dy_vec = output_grad_vec[vec_offset + idx];
+     vec_t dx_vec;
+
+ #pragma unroll
+     for (int i = 0; i < width; ++i) {
+       acc_t x1 = static_cast<acc_t>(x_vec.data[i]);
+       acc_t x2 = x1 * x1;
+       acc_t x3 = x2 * x1;
+       acc_t dy = static_cast<acc_t>(dy_vec.data[i]);
+
+       if (input_grad) {
+         acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
+         acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
+         acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
+         dx_vec.data[i] = static_cast<scalar_t>(dx1 + dx2 + dx3);
+       }
+
+       sum_dy += dy;
+       sum_dw0 += dy * (x3 * inv_std3);
+       sum_dw1 += dy * (x2 * inv_std2);
+       sum_dw2 += dy * (x1 * inv_std1);
+     }
+
+     if (input_grad) {
+       input_grad_vec[vec_offset + idx] = dx_vec;
+     }
+   }
+
+   sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x);
+   __syncthreads();
+   sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x);
+   __syncthreads();
+   sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x);
+   __syncthreads();
+   sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x);
+
+   if (threadIdx.x == 0) {
+     temp_bias_grad[blockIdx.x] = sum_dy;
+     temp_weight_grad[blockIdx.x * 3 + 0] = sum_dw0;
+     temp_weight_grad[blockIdx.x * 3 + 1] = sum_dw1;
+     temp_weight_grad[blockIdx.x * 3 + 2] = sum_dw2;
+   }
+ }
+
+ template <typename scalar_t, typename acc_t, int width>
+ __global__ std::enable_if_t<(width == 0)>
+ poly_norm_backward_kernel(scalar_t *__restrict__ input_grad,        // [..., d]
+                           acc_t *__restrict__ temp_weight_grad,     // [..., 3]
+                           acc_t *__restrict__ temp_bias_grad,       // [..., 1]
+                           const scalar_t *__restrict__ output_grad, // [..., d]
+                           const scalar_t *__restrict__ input,       // [..., d]
+                           const scalar_t *__restrict__ weight,      // [3]
+                           const float eps, const int d) {
+   const int64_t token_idx = blockIdx.x;
+
+   acc_t sum2 = 0.0f;
+   acc_t sum4 = 0.0f;
+   acc_t sum6 = 0.0f;

+   acc_t sum_dx1 = 0.0f;
+   acc_t sum_dx2 = 0.0f;
+   acc_t sum_dx3 = 0.0f;

    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
      acc_t dy = output_grad[token_idx * d + idx];

+     acc_t x1 = input[token_idx * d + idx];
+     acc_t x2 = x1 * x1;
+     acc_t x3 = x2 * x1;
+     acc_t x4 = x2 * x2;
+     acc_t x6 = x3 * x3;

+     sum2 += x2;
+     sum4 += x4;
+     sum6 += x6;

+     sum_dx1 += dy * x1;
+     sum_dx2 += dy * x2;
+     sum_dx3 += dy * x3;
    }

+   using BlockReduce = cub::BlockReduce<float, 1024>;
+   __shared__ typename BlockReduce::TempStorage reduceStore;
+
+   __syncthreads();
+   sum2 = BlockReduce(reduceStore).Sum(sum2, blockDim.x);
+   __syncthreads();
+   sum4 = BlockReduce(reduceStore).Sum(sum4, blockDim.x);
+   __syncthreads();
+   sum6 = BlockReduce(reduceStore).Sum(sum6, blockDim.x);
+
+   __syncthreads();
+   sum_dx1 = BlockReduce(reduceStore).Sum(sum_dx1, blockDim.x);
+   __syncthreads();
+   sum_dx2 = BlockReduce(reduceStore).Sum(sum_dx2, blockDim.x);
+   __syncthreads();
+   sum_dx3 = BlockReduce(reduceStore).Sum(sum_dx3, blockDim.x);
+
+   __shared__ acc_t s_mean2;
+   __shared__ acc_t s_mean4;
+   __shared__ acc_t s_mean6;
+   __shared__ acc_t s_sdx1;
+   __shared__ acc_t s_sdx2;
+   __shared__ acc_t s_sdx3;
+
+   const acc_t inv_d = acc_t(1) / d;
+
+   if (threadIdx.x == 0) {
+     s_mean2 = sum2 * inv_d + eps;
+     s_mean4 = sum4 * inv_d + eps;
+     s_mean6 = sum6 * inv_d + eps;
+
+     s_sdx1 = sum_dx1 * inv_d;
+     s_sdx2 = sum_dx2 * inv_d;
+     s_sdx3 = sum_dx3 * inv_d;
+   }
+   __syncthreads();

+   acc_t w0 = weight[0];
+   acc_t w1 = weight[1];
+   acc_t w2 = weight[2];

+   acc_t mean2 = s_mean2;
+   acc_t mean4 = s_mean4;
+   acc_t mean6 = s_mean6;
+   acc_t sdx1 = s_sdx1;
+   acc_t sdx2 = s_sdx2;
+   acc_t sdx3 = s_sdx3;

+   acc_t inv_std1 = rsqrtf(mean2);
+   acc_t inv_std2 = rsqrtf(mean4);
+   acc_t inv_std3 = rsqrtf(mean6);

+   // inv_std / mean == powf(mean, -1.5)
+   acc_t c1 = w2 * inv_std1 / mean2;
+   acc_t c2 = acc_t(2) * w1 * inv_std2 / mean4;
+   acc_t c3 = acc_t(3) * w0 * inv_std3 / mean6;

+   acc_t sum_dy = 0;
    acc_t sum_dw0 = 0;
    acc_t sum_dw1 = 0;
    acc_t sum_dw2 = 0;

    for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
      acc_t dy = output_grad[token_idx * d + idx];
+     acc_t x1 = input[token_idx * d + idx];
+     acc_t x2 = x1 * x1;
+     acc_t x3 = x2 * x1;

      if (input_grad) {
+       acc_t dx3 = c3 * x2 * (dy * mean6 - x3 * sdx3);
+       acc_t dx2 = c2 * x1 * (dy * mean4 - x2 * sdx2);
+       acc_t dx1 = c1 * (dy * mean2 - x1 * sdx1);
+       input_grad[token_idx * d + idx] = dx1 + dx2 + dx3;
      }

+     sum_dy += dy;
+     sum_dw0 += dy * (x3 * inv_std3);
+     sum_dw1 += dy * (x2 * inv_std2);
+     sum_dw2 += dy * (x1 * inv_std1);
    }

+   sum_dy = BlockReduce(reduceStore).Sum(sum_dy, blockDim.x);
+   __syncthreads();
+   sum_dw0 = BlockReduce(reduceStore).Sum(sum_dw0, blockDim.x);
+   __syncthreads();
+   sum_dw1 = BlockReduce(reduceStore).Sum(sum_dw1, blockDim.x);
+   __syncthreads();
+   sum_dw2 = BlockReduce(reduceStore).Sum(sum_dw2, blockDim.x);
+
+   if (threadIdx.x == 0) {
+     temp_bias_grad[token_idx] = sum_dy;
+     temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
+     temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
+     temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
    }
  }

+ } // namespace motif
+
+ #define LAUNCH_POLY_NORM(width)                                                \
+   MOTIF_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
+     motif::poly_norm_kernel<scalar_t, float, width>                            \
+         <<<grid, block, 0, stream>>>(                                          \
+             out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),              \
+             weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), eps, d);   \
+   });
+
+ void poly_norm(torch::Tensor &out,          // [..., d]
+                const torch::Tensor &input,  // [..., d]
+                const torch::Tensor &weight, // [3]
+                const torch::Tensor &bias,   // [1]
+                double eps) {
    AssertTensorShapeEqual(input, out, "input", "out");
    AssertTensorNotNull(weight, "weight");
    AssertTensorNotNull(bias, "bias");
    // TODO shape check

    int d = input.size(-1);
+   int64_t num_tokens = input.numel() / d;
    dim3 grid(num_tokens);
+   const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+   dim3 block(std::min(d, max_block_size));

    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   if (d % 8 == 0) {
+     LAUNCH_POLY_NORM(8);
+   } else {
+     LAUNCH_POLY_NORM(0);
+   }
  }

+ #define LAUNCH_POLY_NORM_BACKWARD(width)                                       \
+   MOTIF_DISPATCH_FLOATING_TYPES(                                               \
+       input.scalar_type(), "poly_norm_backward_kernel", [&] {                  \
+         motif::poly_norm_backward_kernel<scalar_t, float, width>               \
+             <<<grid, block, 0, stream>>>(input_grad.data_ptr<scalar_t>(),      \
+                                          temp_weight_grad.data_ptr<float>(),   \
+                                          temp_bias_grad.data_ptr<float>(),     \
+                                          output_grad.data_ptr<scalar_t>(),     \
+                                          input.data_ptr<scalar_t>(),           \
+                                          weight.data_ptr<scalar_t>(), eps, d); \
+       });
+
+ void poly_norm_backward(torch::Tensor &input_grad,        // [..., d]
+                         torch::Tensor &weight_grad,       // [3]
+                         torch::Tensor &bias_grad,         // [1]
+                         const torch::Tensor &output_grad, // [..., d]
+                         const torch::Tensor &input,       // [..., d]
+                         const torch::Tensor &weight,      // [3]
+                         double eps) {
    AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
    AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
    AssertTensorNotNull(weight, "weight");
    // TODO shape check
    // weight_grad, bias_grad and input_grad can be nullable

    int d = input.size(-1);
+   int64_t num_tokens = input.numel() / d;
    dim3 grid(num_tokens);
+   const int max_block_size = (num_tokens < 256) ? 1024 : 256;
+   dim3 block(std::min(d, max_block_size));

    torch::Tensor temp_weight_grad =
+       torch::empty({num_tokens, 3}, input.options().dtype(torch::kFloat));
+   torch::Tensor temp_bias_grad =
+       torch::empty({num_tokens, 1}, output_grad.options().dtype(torch::kFloat));

    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   if (d % 8 == 0) {
+     LAUNCH_POLY_NORM_BACKWARD(8);
+   } else {
+     LAUNCH_POLY_NORM_BACKWARD(0);
+   }

    if (bias_grad.defined()) {
+     torch::Tensor acc = torch::empty_like(bias_grad, temp_bias_grad.options());
+     at::sum_out(acc, temp_bias_grad, {0});
+     bias_grad.copy_(acc);
    }

    if (weight_grad.defined()) {
+     torch::Tensor acc =
+         torch::empty_like(weight_grad, temp_weight_grad.options());
+     at::sum_out(acc, temp_weight_grad, {0});
+     weight_grad.copy_(acc);
    }
  }
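For readers checking the rewritten backward kernel against the one it replaces, the per-element input gradient both versions compute can be written as follows (reconstructed from the kernel source; the notation is ours):

\[
\frac{\partial L}{\partial x_i}
  = \sum_{k=1}^{3} k\, w_{3-k}\,
    \frac{x_i^{\,k-1}\bigl(\mathrm{d}y_i\, m_k - x_i^{\,k}\, s_k\bigr)}{m_k^{3/2}},
\qquad
m_k = \overline{x^{2k}} + \varepsilon,
\quad
s_k = \frac{1}{d}\sum_{j=1}^{d} \mathrm{d}y_j\, x_j^{\,k}
\]

The new kernel additionally accumulates per-token weight gradients of the form sum(dy * x^k / sqrt(m_k)) and the per-token bias gradient sum(dy) into temp_weight_grad and temp_bias_grad, which the host code then reduces over tokens with at::sum_out.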
activation/poly_norm_naive.cu DELETED
@@ -1,246 +0,0 @@
- #include <ATen/cuda/CUDAContext.h>
- #include <ATen/Functions.h>
- #include <torch/all.h>
- #include <c10/cuda/CUDAGuard.h>
-
- #include <cmath>
-
- #include "cuda_compat.h"
- #include "dispatch_utils.h"
- #include "assert_utils.h"
- #include "atomic_utils.h"
- #include "block_reduce.h"
-
- namespace motif {
-
- template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
- __global__ void poly_norm_naive_kernel(
-     scalar_t* __restrict__ out,          // [..., d]
-     const scalar_t* __restrict__ input,  // [..., d]
-     const scalar_t* __restrict__ weight, // [3]
-     const scalar_t* __restrict__ bias,   // [1]
-     const float eps,
-     const int d
- ) {
-   const int64_t token_idx = blockIdx.x;
-
-   acc_t sum = 0.0f;
-   acc_t sum_square = 0.0f;
-   acc_t sum_cube = 0.0f;
-
-   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-     acc_t x = input[token_idx * d + idx];
-     sum += pow(x, 2.0f);
-     sum_square += pow(x, 4.0f);
-     sum_cube += pow(x, 6.0f);
-   }
-
-   __shared__ acc_t shared[BLOCK_SIZE];
-
-   acc_t mean = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum, d) / d;
-   acc_t mean_square = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
-   acc_t mean_cube = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_cube, d) / d;
-
-   acc_t w0 = weight[0];
-   acc_t w1 = weight[1];
-   acc_t w2 = weight[2];
-   acc_t b = bias[0];
-
-   acc_t divisor = sqrt(mean + eps);
-   acc_t divisor_square = sqrt(mean_square + eps);
-   acc_t divisor_cube = sqrt(mean_cube + eps);
-
-   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-     acc_t x = input[token_idx * d + idx];
-     acc_t x_square = pow(x, 2.0f);
-     acc_t x_cube = pow(x, 3.0f);
-     out[token_idx * d + idx] = w2 * x / divisor +
-         w1 * x_square / divisor_square +
-         w0 * x_cube / divisor_cube + b;
-   }
- }
-
- template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
- __global__ void poly_norm_naive_backward_kernel(
-     scalar_t* __restrict__ input_grad,        // [..., d]
-     acc_t* __restrict__ temp_weight_grad,     // [..., 3]
-     const scalar_t* __restrict__ output_grad, // [..., d]
-     const scalar_t* __restrict__ input,       // [..., d]
-     const scalar_t* __restrict__ weight,      // [3]
-     const float eps,
-     const int d
- ) {
-   const int64_t token_idx = blockIdx.x;
-
-   acc_t w0 = weight[0];
-   acc_t w1 = weight[1];
-   acc_t w2 = weight[2];
-
-   acc_t sum_2 = 0.0f;
-   acc_t sum_4 = 0.0f;
-   acc_t sum_6 = 0.0f;
-
-   acc_t sum_dx_1 = 0.0f;
-   acc_t sum_dx_2 = 0.0f;
-   acc_t sum_dx_3 = 0.0f;
-
-   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-     acc_t dy = output_grad[token_idx * d + idx];
-
-     acc_t x_1 = input[token_idx * d + idx];
-     acc_t x_2 = x_1 * x_1;
-     acc_t x_3 = x_2 * x_1;
-     acc_t x_4 = x_2 * x_2;
-     acc_t x_6 = x_3 * x_3;
-
-     sum_2 += x_2;
-     sum_4 += x_4;
-     sum_6 += x_6;
-
-     sum_dx_1 += dy * x_1;
-     sum_dx_2 += dy * x_2;
-     sum_dx_3 += dy * x_3;
-   }
-
-   __shared__ acc_t shared[BLOCK_SIZE];
-
-   acc_t mean_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_2, d) / d + eps;
-   acc_t mean_4 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_4, d) / d + eps;
-   acc_t mean_6 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_6, d) / d + eps;
-
-   sum_dx_1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_1, d);
-   sum_dx_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_2, d);
-   sum_dx_3 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_3, d);
-
-   acc_t _mean_2 = powf(mean_2, -1.5);
-   acc_t _mean_4 = powf(mean_4, -1.5);
-   acc_t _mean_6 = powf(mean_6, -1.5);
-
-   acc_t sq_mean_2 = sqrtf(mean_2);
-   acc_t sq_mean_4 = sqrtf(mean_4);
-   acc_t sq_mean_6 = sqrtf(mean_6);
-
-   acc_t sum_dw0 = 0;
-   acc_t sum_dw1 = 0;
-   acc_t sum_dw2 = 0;
-
-   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-     acc_t dy = output_grad[token_idx * d + idx];
-     acc_t x_1 = input[token_idx * d + idx];
-     acc_t x_2 = x_1 * x_1;
-     acc_t x_3 = x_2 * x_1;
-
-     acc_t dx_3 =
-         _mean_6 * 3 * x_2 * (dy * mean_6 - x_3 * sum_dx_3 / d) * w0;
-     acc_t dx_2 =
-         _mean_4 * 2 * x_1 * (dy * mean_4 - x_2 * sum_dx_2 / d) * w1;
-     acc_t dx_1 =
-         _mean_2 * (dy * mean_2 - x_1 * sum_dx_1 / d) * w2;
-
-     if (input_grad) {
-       input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3;
-     }
-
-     sum_dw0 += dy * (x_3 / sq_mean_6);
-     sum_dw1 += dy * (x_2 / sq_mean_4);
-     sum_dw2 += dy * (x_1 / sq_mean_2);
-   }
-
-   if (temp_weight_grad) {
-     sum_dw0 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw0, d);
-     sum_dw1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw1, d);
-     sum_dw2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw2, d);
-
-     if (threadIdx.x == 0) {
-       temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
-       temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
-       temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
-     }
-   }
- }
-
- } // namespace motif
-
-
- void poly_norm_naive(torch::Tensor& out,    // [..., d]
-     const torch::Tensor& input,             // [..., d]
-     const torch::Tensor& weight,            // [3]
-     const torch::Tensor& bias,              // [1]
-     double eps)
- {
-   AssertTensorShapeEqual(input, out, "input", "out");
-   AssertTensorNotNull(weight, "weight");
-   AssertTensorNotNull(bias, "bias");
-   // TODO shape check
-
-   constexpr int BLOCK_SIZE = 256;
-
-   int d = input.size(-1);
-   int64_t num_tokens = input.numel() / input.size(-1);
-   dim3 grid(num_tokens);
-   dim3 block(BLOCK_SIZE);
-
-   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-   MOTIF_DISPATCH_FLOATING_TYPES(
-       input.scalar_type(), "poly_norm_naive_kernel", [&] {
-         motif::poly_norm_naive_kernel<scalar_t, float, BLOCK_SIZE>
-             <<<grid, block, 0, stream>>>(
-                 out.data_ptr<scalar_t>(),
-                 input.data_ptr<scalar_t>(),
-                 weight.data_ptr<scalar_t>(),
-                 bias.data_ptr<scalar_t>(), eps, d);
-       }
-   );
- }
-
- void poly_norm_naive_backward(
-     torch::Tensor& input_grad,        // [..., d]
-     torch::Tensor& weight_grad,       // [..., d]
-     torch::Tensor& bias_grad,         // [..., d]
-     const torch::Tensor& output_grad, // [3]
-     const torch::Tensor& input,       // [3]
-     const torch::Tensor& weight,      // [3]
-     double eps) {
-   AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
-   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
-   AssertTensorNotNull(weight, "weight");
-   // TODO shape check
-   // weight_grad, bias_grad and input_grad can be nullable
-
-   constexpr int BLOCK_SIZE = 256;
-
-   int d = input.size(-1);
-   int64_t num_tokens = input.numel() / input.size(-1);
-   dim3 grid(num_tokens);
-   dim3 block(BLOCK_SIZE);
-
-   torch::Tensor temp_weight_grad =
-       torch::empty({num_tokens, 3},
-           input.options().dtype(torch::kFloat));
-
-   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-   MOTIF_DISPATCH_FLOATING_TYPES(
-       input.scalar_type(), "poly_norm_naive_backward_kernel", [&] {
-         motif::poly_norm_naive_backward_kernel<scalar_t, float, BLOCK_SIZE>
-             <<<grid, block, 0, stream>>>(
-                 input_grad.data_ptr<scalar_t>(),
-                 temp_weight_grad.data_ptr<float>(),
-                 output_grad.data_ptr<scalar_t>(),
-                 input.data_ptr<scalar_t>(),
-                 weight.data_ptr<scalar_t>(),
-                 eps, d);
-       }
-   );
-
-   if (bias_grad.defined()) {
-     at::sum_out(bias_grad, output_grad);
-     bias_grad.resize_({1});
-   }
-
-   if (weight_grad.defined()) {
-     at::sum_out(weight_grad, temp_weight_grad, {0});
-   }
- }
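The deleted naive kernels above read and process one scalar per loop iteration. The replacement in poly_norm.cu instead loads width elements per iteration through an aligned vector type (type_vec_t) whenever d is divisible by 8. A minimal sketch of that load pattern, with illustrative names that are not part of the commit:

// Illustrative only: aligned vector type for wide loads, mirroring
// type_vec_t in the new poly_norm.cu. With a 2-byte scalar_t and N = 8,
// each element access becomes a single 16-byte transaction.
template <typename T, int N>
struct alignas(sizeof(T) * N) vec_example_t {
  T data[N];
};

template <typename scalar_t, int N>
__global__ void copy_vectorized(scalar_t *out, const scalar_t *in, int n_vec) {
  using V = vec_example_t<scalar_t, N>;
  const V *in_v = reinterpret_cast<const V *>(in);
  V *out_v = reinterpret_cast<V *>(out);
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n_vec;
       i += gridDim.x * blockDim.x) {
    out_v[i] = in_v[i];  // one wide load and one wide store per vector
  }
}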