ca1207 committed on
Commit 18ec195 · 1 Parent(s): 43629b7

add poly_norm_naive.cu for temp test

Files changed (1)
  1. activation/poly_norm_naive.cu +246 -0
activation/poly_norm_naive.cu ADDED
@@ -0,0 +1,246 @@
+ #include <ATen/cuda/CUDAContext.h>
+ #include <ATen/Functions.h>
+ #include <torch/all.h>
+ #include <c10/cuda/CUDAGuard.h>
+
+ #include <cmath>
+
+ #include "cuda_compat.h"
+ #include "dispatch_utils.h"
+ #include "assert_utils.h"
+ #include "atomic_utils.h"
+ #include "block_reduce.h"
+
+ namespace motif {
+
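+ // PolyNorm forward: for each token row x of length d,
+ //   out = w2 * x / rms(x) + w1 * x^2 / rms(x^2) + w0 * x^3 / rms(x^3) + b,
+ // where rms(u) = sqrt(mean(u^2) + eps). One block handles one token row;
+ // the three means are accumulated in a single strided pass, then reduced.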
+ template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
+ __global__ void poly_norm_naive_kernel(
+     scalar_t* __restrict__ out,           // [..., d]
+     const scalar_t* __restrict__ input,   // [..., d]
+     const scalar_t* __restrict__ weight,  // [3]
+     const scalar_t* __restrict__ bias,    // [1]
+     const float eps,
+     const int d
+ ) {
+   const int64_t token_idx = blockIdx.x;
+
+   acc_t sum = 0.0f;
+   acc_t sum_square = 0.0f;
+   acc_t sum_cube = 0.0f;
+
+   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t x = input[token_idx * d + idx];
+     sum += pow(x, 2.0f);
+     sum_square += pow(x, 4.0f);
+     sum_cube += pow(x, 6.0f);
+   }
+
+   __shared__ acc_t shared[BLOCK_SIZE];
+
+   acc_t mean = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum, d) / d;
+   acc_t mean_square = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_square, d) / d;
+   acc_t mean_cube = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_cube, d) / d;
+
+   acc_t w0 = weight[0];
+   acc_t w1 = weight[1];
+   acc_t w2 = weight[2];
+   acc_t b = bias[0];
+
+   acc_t divisor = sqrt(mean + eps);
+   acc_t divisor_square = sqrt(mean_square + eps);
+   acc_t divisor_cube = sqrt(mean_cube + eps);
+
+   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t x = input[token_idx * d + idx];
+     acc_t x_square = pow(x, 2.0f);
+     acc_t x_cube = pow(x, 3.0f);
+     out[token_idx * d + idx] = w2 * x / divisor +
+                                w1 * x_square / divisor_square +
+                                w0 * x_cube / divisor_cube + b;
+   }
+ }
+
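+ // Backward: with m_k = mean(x^(2k)) + eps, each term y = w * x^k / sqrt(m_k)
+ // has per-element gradient
+ //   dx = w * k * x^(k-1) * (dy * m_k - x^k * sum(dy * x^k) / d) * m_k^(-3/2),
+ // which the fused expressions below compute for k = 1, 2, 3. Per-token
+ // weight gradients go to temp_weight_grad and are summed on the host.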
+ template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
+ __global__ void poly_norm_naive_backward_kernel(
+     scalar_t* __restrict__ input_grad,          // [..., d]
+     acc_t* __restrict__ temp_weight_grad,       // [..., 3]
+     const scalar_t* __restrict__ output_grad,   // [..., d]
+     const scalar_t* __restrict__ input,         // [..., d]
+     const scalar_t* __restrict__ weight,        // [3]
+     const float eps,
+     const int d
+ ) {
+   const int64_t token_idx = blockIdx.x;
+
+   acc_t w0 = weight[0];
+   acc_t w1 = weight[1];
+   acc_t w2 = weight[2];
+
+   acc_t sum_2 = 0.0f;
+   acc_t sum_4 = 0.0f;
+   acc_t sum_6 = 0.0f;
+
+   acc_t sum_dx_1 = 0.0f;
+   acc_t sum_dx_2 = 0.0f;
+   acc_t sum_dx_3 = 0.0f;
+
+   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t dy = output_grad[token_idx * d + idx];
+
+     acc_t x_1 = input[token_idx * d + idx];
+     acc_t x_2 = x_1 * x_1;
+     acc_t x_3 = x_2 * x_1;
+     acc_t x_4 = x_2 * x_2;
+     acc_t x_6 = x_3 * x_3;
+
+     sum_2 += x_2;
+     sum_4 += x_4;
+     sum_6 += x_6;
+
+     sum_dx_1 += dy * x_1;
+     sum_dx_2 += dy * x_2;
+     sum_dx_3 += dy * x_3;
+   }
+
+   __shared__ acc_t shared[BLOCK_SIZE];
+
+   acc_t mean_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_2, d) / d + eps;
+   acc_t mean_4 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_4, d) / d + eps;
+   acc_t mean_6 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_6, d) / d + eps;
+
+   sum_dx_1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_1, d);
+   sum_dx_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_2, d);
+   sum_dx_3 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_3, d);
+
+   acc_t _mean_2 = powf(mean_2, -1.5f);
+   acc_t _mean_4 = powf(mean_4, -1.5f);
+   acc_t _mean_6 = powf(mean_6, -1.5f);
+
+   acc_t sq_mean_2 = sqrtf(mean_2);
+   acc_t sq_mean_4 = sqrtf(mean_4);
+   acc_t sq_mean_6 = sqrtf(mean_6);
+
+   acc_t sum_dw0 = 0;
+   acc_t sum_dw1 = 0;
+   acc_t sum_dw2 = 0;
+
+   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+     acc_t dy = output_grad[token_idx * d + idx];
+     acc_t x_1 = input[token_idx * d + idx];
+     acc_t x_2 = x_1 * x_1;
+     acc_t x_3 = x_2 * x_1;
+
+     acc_t dx_3 =
+         _mean_6 * 3 * x_2 * (dy * mean_6 - x_3 * sum_dx_3 / d) * w0;
+     acc_t dx_2 =
+         _mean_4 * 2 * x_1 * (dy * mean_4 - x_2 * sum_dx_2 / d) * w1;
+     acc_t dx_1 =
+         _mean_2 * (dy * mean_2 - x_1 * sum_dx_1 / d) * w2;
+
+     if (input_grad) {
+       input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3;
+     }
+
+     sum_dw0 += dy * (x_3 / sq_mean_6);
+     sum_dw1 += dy * (x_2 / sq_mean_4);
+     sum_dw2 += dy * (x_1 / sq_mean_2);
+   }
+
+   if (temp_weight_grad) {
+     sum_dw0 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw0, d);
+     sum_dw1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw1, d);
+     sum_dw2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw2, d);
+
+     if (threadIdx.x == 0) {
+       temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
+       temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
+       temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
+     }
+   }
+ }
+
+ } // namespace motif
+
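+ // Host entry points: launch one block per token row, with BLOCK_SIZE threads
+ // striding across the hidden dimension d.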
+ void poly_norm_naive(torch::Tensor& out,          // [..., d]
+                      const torch::Tensor& input,  // [..., d]
+                      const torch::Tensor& weight, // [3]
+                      const torch::Tensor& bias,   // [1]
+                      double eps)
+ {
+   AssertTensorShapeEqual(input, out, "input", "out");
+   AssertTensorNotNull(weight, "weight");
+   AssertTensorNotNull(bias, "bias");
+   // TODO shape check
+
+   constexpr int BLOCK_SIZE = 256;
+
+   int d = input.size(-1);
+   int64_t num_tokens = input.numel() / input.size(-1);
+   dim3 grid(num_tokens);
+   dim3 block(BLOCK_SIZE);
+
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   MOTIF_DISPATCH_FLOATING_TYPES(
+       input.scalar_type(), "poly_norm_naive_kernel", [&] {
+         motif::poly_norm_naive_kernel<scalar_t, float, BLOCK_SIZE>
+             <<<grid, block, 0, stream>>>(
+                 out.data_ptr<scalar_t>(),
+                 input.data_ptr<scalar_t>(),
+                 weight.data_ptr<scalar_t>(),
+                 bias.data_ptr<scalar_t>(), eps, d);
+       }
+   );
+ }
+
+ void poly_norm_naive_backward(
+     torch::Tensor& input_grad,         // [..., d]
+     torch::Tensor& weight_grad,        // [3]
+     torch::Tensor& bias_grad,          // [1]
+     const torch::Tensor& output_grad,  // [..., d]
+     const torch::Tensor& input,        // [..., d]
+     const torch::Tensor& weight,       // [3]
+     double eps) {
+   AssertTensorShapeEqual(input, input_grad, "input", "input_grad");
+   AssertTensorShapeEqual(input, output_grad, "input", "output_grad");
+   AssertTensorNotNull(weight, "weight");
+   // TODO shape check
+   // weight_grad, bias_grad and input_grad can be nullable (undefined)
+
+   constexpr int BLOCK_SIZE = 256;
+
+   int d = input.size(-1);
+   int64_t num_tokens = input.numel() / input.size(-1);
+   dim3 grid(num_tokens);
+   dim3 block(BLOCK_SIZE);
+
+   torch::Tensor temp_weight_grad =
+       torch::empty({num_tokens, 3},
+                    input.options().dtype(torch::kFloat));
+
+   // Set the device guard before querying the stream so the stream belongs
+   // to input's device (matches the forward entry point).
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   MOTIF_DISPATCH_FLOATING_TYPES(
+       input.scalar_type(), "poly_norm_naive_backward_kernel", [&] {
+         motif::poly_norm_naive_backward_kernel<scalar_t, float, BLOCK_SIZE>
+             <<<grid, block, 0, stream>>>(
+                 // input_grad may be undefined; pass nullptr so the kernel
+                 // skips the input-gradient store.
+                 input_grad.defined() ? input_grad.data_ptr<scalar_t>()
+                                      : nullptr,
+                 temp_weight_grad.data_ptr<float>(),
+                 output_grad.data_ptr<scalar_t>(),
+                 input.data_ptr<scalar_t>(),
+                 weight.data_ptr<scalar_t>(),
+                 eps, d);
+       }
+   );
+
+   if (bias_grad.defined()) {
+     // d(out)/d(bias) is 1, so the bias gradient is the sum of output grads.
+     at::sum_out(bias_grad, output_grad);
+     bias_grad.resize_({1});
+   }
+
+   if (weight_grad.defined()) {
+     // Reduce the per-token weight gradients to the final [3] tensor.
+     at::sum_out(weight_grad, temp_weight_grad, {0});
+   }
+ }
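
Since the commit message marks this file as a temporary test, a smoke test along the following lines could exercise the forward entry point. This sketch is not part of the commit: it assumes a build that links libtorch and the file above, and `poly_norm_reference` is a hypothetical helper that mirrors the kernel math (out = w2 * x / rms(x) + w1 * x^2 / rms(x^2) + w0 * x^3 / rms(x^3) + b, with rms(u) = sqrt(mean(u^2) + eps)).

// Hypothetical smoke test for poly_norm_naive (not part of this commit).
#include <torch/torch.h>

// Declared in activation/poly_norm_naive.cu.
void poly_norm_naive(torch::Tensor& out, const torch::Tensor& input,
                     const torch::Tensor& weight, const torch::Tensor& bias,
                     double eps);

// Reference implementation of the same math using libtorch ops.
static torch::Tensor poly_norm_reference(const torch::Tensor& x,
                                         const torch::Tensor& w,
                                         const torch::Tensor& b, double eps) {
  auto norm = [eps](const torch::Tensor& u) {
    return u / torch::sqrt(u.pow(2).mean(-1, /*keepdim=*/true) + eps);
  };
  return w[2] * norm(x) + w[1] * norm(x.pow(2)) + w[0] * norm(x.pow(3)) + b;
}

int main() {
  auto x = torch::randn({4, 1024}, torch::kCUDA);
  auto w = torch::randn({3}, torch::kCUDA);
  auto b = torch::randn({1}, torch::kCUDA);
  auto out = torch::empty_like(x);

  poly_norm_naive(out, x, w, b, /*eps=*/1e-6);

  TORCH_CHECK(torch::allclose(out, poly_norm_reference(x, w, b, 1e-6),
                              /*rtol=*/1e-3, /*atol=*/1e-3),
              "poly_norm_naive disagrees with the libtorch reference");
  return 0;
}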