Commit 704692b
Parent(s): 883cc1c
fix(poly-norm): calc param grad explicitly

Browse files
- activation/activation_kernels.cu +28 -33
- activation/atomic_utils.h +16 -16
- build/flake.lock +168 -0
- build/flake.nix +11 -0
- build/torch26-cxx11-rocm62-x86_64-linux/activation/{_activation_32c2bde_dirty.abi3.so → _activation_883cc1c_dirty.abi3.so} +2 -2
- build/torch26-cxx11-rocm62-x86_64-linux/activation/_activation_f72121c_dirty.abi3.so +0 -3
- build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py +3 -3
- build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_32c2bde_dirty.abi3.so +0 -3
- build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_552d415_dirty.abi3.so +0 -3
- build/{torch26-cxx11-rocm62-x86_64-linux/activation/_activation_552d415_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_883cc1c_dirty.abi3.so} +2 -2
- build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f72121c_dirty.abi3.so +0 -3
- build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py +3 -3
- tests/pytest.ini +3 -0
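
For context on "calc param grad explicitly": assuming the usual PolyNorm forward (my reading of the kernel below; the exact placement of eps is an assumption),

\[
\mathrm{PolyNorm}(x)_i = w_0\,\frac{x_i^3}{\sqrt{\overline{x^6}+\varepsilon}}
                       + w_1\,\frac{x_i^2}{\sqrt{\overline{x^4}+\varepsilon}}
                       + w_2\,\frac{x_i}{\sqrt{\overline{x^2}+\varepsilon}} + b,
\qquad
\overline{x^k} = \tfrac{1}{d}\sum_{j=1}^{d} x_j^k,
\]

the parameter gradients are plain sums over tokens and hidden positions:

\[
\frac{\partial L}{\partial w_0} = \sum_{\mathrm{tokens}}\ \sum_{i=1}^{d}
  \frac{\partial L}{\partial y_i}\,\frac{x_i^3}{\sqrt{\overline{x^6}+\varepsilon}},
\qquad
\frac{\partial L}{\partial b} = \sum_{\mathrm{tokens}}\ \sum_{i=1}^{d}\frac{\partial L}{\partial y_i},
\]

with the analogous expressions for w_1 (using x^2) and w_2 (using x). The kernel now writes the inner per-token sums into a float temp_weight_grad buffer of shape [num_tokens, 3] and reduces over tokens with at::sum_out on the host side, instead of atomically accumulating partial sums into the parameter gradients.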
activation/activation_kernels.cu
CHANGED
@@ -1,4 +1,5 @@
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/Functions.h>
 #include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
@@ -78,8 +79,7 @@ __global__ void poly_norm_kernel(
 template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
 __global__ void poly_norm_backward_kernel(
     scalar_t* __restrict__ input_grad,         // [..., d]
-    scalar_t* __restrict__ weight_grad,        // [3]
-    scalar_t* __restrict__ bias_grad,          // [1]
+    acc_t* __restrict__ temp_weight_grad,      // [..., 3]
     const scalar_t* __restrict__ output_grad,  // [..., d]
     const scalar_t* __restrict__ input,        // [..., d]
     const scalar_t* __restrict__ weight,       // [3]
@@ -128,14 +128,17 @@ __global__ void poly_norm_backward_kernel(
   sum_dx_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_2, d);
   sum_dx_3 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_3, d);
 
-  acc_t
-  acc_t
-  acc_t
+  acc_t _mean_2 = powf(mean_2, -1.5);
+  acc_t _mean_4 = powf(mean_4, -1.5);
+  acc_t _mean_6 = powf(mean_6, -1.5);
+
+  acc_t sq_mean_2 = sqrtf(mean_2);
+  acc_t sq_mean_4 = sqrtf(mean_4);
+  acc_t sq_mean_6 = sqrtf(mean_6);
 
   acc_t sum_dw0 = 0;
   acc_t sum_dw1 = 0;
   acc_t sum_dw2 = 0;
-  acc_t sum_db = 0;
 
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     acc_t dy = output_grad[token_idx * d + idx];
@@ -144,38 +147,30 @@
     acc_t x_3 = x_2 * x_1;
 
     acc_t dx_3 =
-
+        _mean_6 * 3 * x_2 * (dy * mean_6 - x_3 * sum_dx_3 / d) * w0;
     acc_t dx_2 =
-
+        _mean_4 * 2 * x_1 * (dy * mean_4 - x_2 * sum_dx_2 / d) * w1;
     acc_t dx_1 =
-
+        _mean_2 * (dy * mean_2 - x_1 * sum_dx_1 / d) * w2;
 
     if (input_grad) {
       input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3;
     }
 
-    sum_dw0 += dy * (x_3 /
-    sum_dw1 += dy * (x_2 /
-    sum_dw2 += dy * (x_1 /
-    sum_db += dy;
+    sum_dw0 += dy * (x_3 / sq_mean_6);
+    sum_dw1 += dy * (x_2 / sq_mean_4);
+    sum_dw2 += dy * (x_1 / sq_mean_2);
   }
 
-  if (weight_grad) {
+  if (temp_weight_grad) {
     sum_dw0 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw0, d);
     sum_dw1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw1, d);
     sum_dw2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw2, d);
 
     if (threadIdx.x == 0) {
-      atomic_add(&weight_grad[0], sum_dw0);
-      atomic_add(&weight_grad[1], sum_dw1);
-      atomic_add(&weight_grad[2], sum_dw2);
-    }
-  }
-
-  if (bias_grad) {
-    sum_db = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_db, d);
-    if (threadIdx.x == 0) {
-      atomic_add(&bias_grad[0], sum_db);
+      temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
+      temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
+      temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
     }
   }
 }
@@ -236,14 +231,11 @@ void poly_norm_backward(
   dim3 grid(num_tokens);
   dim3 block(BLOCK_SIZE);
 
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  torch::Tensor temp_weight_grad =
+      torch::empty({num_tokens, 3},
+                   input.options().dtype(torch::kFloat));
 
-  if (weight_grad.defined()) {
-    cudaMemsetAsync(weight_grad.data_ptr(), 0, weight_grad.numel() * weight_grad.element_size(), stream);
-  }
-  if (bias_grad.defined()) {
-    cudaMemsetAsync(bias_grad.data_ptr(), 0, bias_grad.numel() * bias_grad.element_size(), stream);
-  }
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   MOTIF_DISPATCH_FLOATING_TYPES(
@@ -251,12 +243,15 @@
       motif::poly_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
          <<<grid, block, 0, stream>>>(
              input_grad.data_ptr<scalar_t>(),
-             weight_grad.data_ptr<scalar_t>(),
-             bias_grad.data_ptr<scalar_t>(),
+             temp_weight_grad.data_ptr<float>(),
             output_grad.data_ptr<scalar_t>(),
             input.data_ptr<scalar_t>(),
             weight.data_ptr<scalar_t>(),
            eps, d);
    }
  );
+
+  at::sum_out(bias_grad, output_grad);
+  at::sum_out(weight_grad, temp_weight_grad, {0});
+  bias_grad.resize_({1});
 }
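A minimal reference for the same computation in Python, useful for checking the kernel (a sketch under the assumptions above; poly_norm_param_grads is a hypothetical helper, not part of this repo):

import torch

def poly_norm_param_grads(output_grad, x, eps=1e-6):
    # output_grad, x: [num_tokens, d]; returns (weight_grad [3], bias_grad [1]).
    def norm(u):
        # Per-token RMS normalization; the eps placement is assumed.
        return u / torch.sqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    # Per-token partial sums: the analogue of the kernel's float
    # temp_weight_grad buffer of shape [num_tokens, 3].
    temp = torch.stack([(output_grad * norm(x**3)).sum(-1),
                        (output_grad * norm(x**2)).sum(-1),
                        (output_grad * norm(x)).sum(-1)], dim=-1).float()

    weight_grad = temp.sum(dim=0)             # mirrors at::sum_out(weight_grad, temp_weight_grad, {0})
    bias_grad = output_grad.sum().reshape(1)  # mirrors at::sum_out(bias_grad, output_grad) + resize_({1})
    return weight_grad, bias_grad

Up to accumulation precision, the kernel's weight_grad and bias_grad outputs should agree with this if the assumed forward matches.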
activation/atomic_utils.h
CHANGED
@@ -27,19 +27,19 @@ __device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16* _address,
 
   size_t offset = (size_t)address & 0x2;
   volatile uint16_t* address_as_short =
-
+      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
   volatile uint32_t* address_as_uint =
-
-
+      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+  bool is_32bit_aligned = offset == 0;
 
-
-
+  uint32_t current = address_as_uint[0];
+  uint32_t expected;
 
   do {
     expected = current;
-
-
-
+    c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
+    c10::BFloat16 next_bf16 = current_bf16 + value;
+    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_bf16.x
                                      : (current & 0x0000ffff) | (next_bf16.x << 16);
     current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
   } while (current != expected);
@@ -51,19 +51,19 @@ __device__ inline void atomic_add<c10::Half, float>(c10::Half* _address, float v
 
   size_t offset = (size_t)address & 0x2;
   volatile uint16_t* address_as_short =
-
+      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
   volatile uint32_t* address_as_uint =
-
-
+      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+  bool is_32bit_aligned = offset == 0;
 
-
-
+  uint32_t current = address_as_uint[0];
+  uint32_t expected;
 
   do {
     expected = current;
-
-
-
+    c10::Half current_half(address_as_short[0], c10::Half::from_bits());
+    c10::Half next_half = current_half + value;
+    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_half.x
                                      : (current & 0x0000ffff) | (next_half.x << 16);
     current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
   } while (current != expected);
build/flake.lock
ADDED
@@ -0,0 +1,168 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-compat_2": {
+      "locked": {
+        "lastModified": 1733328505,
+        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "flake-utils_2": {
+      "inputs": {
+        "systems": "systems_2"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "hf-nix": {
+      "inputs": {
+        "flake-compat": "flake-compat_2",
+        "flake-utils": "flake-utils_2",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1747919133,
+        "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "hf-nix": "hf-nix",
+        "nixpkgs": [
+          "kernel-builder",
+          "hf-nix",
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1748620233,
+        "narHash": "sha256-VULm9HgGXvo3pyfsPy3SOhoqgkuqbGSaSemvzNUbdIU=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "da3340e5b3cbb6086600420f4814b033395788d1",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1747820358,
+        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
+        "owner": "danieldk",
+        "repo": "nixpkgs",
+        "rev": "d3c1681180717528068082103bf323147de6ab0b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "danieldk",
+        "ref": "cudatoolkit-12.9-kernel-builder",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "kernel-builder": "kernel-builder"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    },
+    "systems_2": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
build/flake.nix
ADDED
@@ -0,0 +1,11 @@
+{
+  description = "Flake for Torch kernel extension";
+  inputs = {
+    kernel-builder.url = "github:huggingface/kernel-builder";
+  };
+  outputs = { self, kernel-builder, }:
+    kernel-builder.lib.genFlakeOutputs {
+      path = ./.;
+      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+    };
+}
build/torch26-cxx11-rocm62-x86_64-linux/activation/{_activation_32c2bde_dirty.abi3.so → _activation_883cc1c_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:a9d74188efdcb10158b338cf363749494f86e9712797722310f0a6ac5310efdd
+size 2401160
build/torch26-cxx11-rocm62-x86_64-linux/activation/_activation_f72121c_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b358f5be0dc4f1c1d7198ca4417c74cf9626f678b89772e178154acbaee1476a
-size 2460736
build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _activation_883cc1c_dirty
+ops = torch.ops._activation_883cc1c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_activation_883cc1c_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_32c2bde_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c1f586de8406ba777d2d80bbcad8cc711032ef3971c1e963c7d31845c25b28c8
-size 2404376
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_552d415_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eecd2db703418250f22fbccdda97d8f45a15d3a8d34d1c6be1f0a1e3a7076990
-size 2447480
build/{torch26-cxx11-rocm62-x86_64-linux/activation/_activation_552d415_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_883cc1c_dirty.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:719fc6521c0824b253cb11ea9e564ef7835e2102e5bc6399cfdb69203d6d5c26
+size 2395176
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f72121c_dirty.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d322fb12e4bd5eab4700783a6cfac4a8a9f9f21c7c61fd2ddb47253da8e182f1
-size 2447176
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _activation_883cc1c_dirty
+ops = torch.ops._activation_883cc1c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_activation_883cc1c_dirty::{op_name}"
tests/pytest.ini
ADDED
@@ -0,0 +1,3 @@
+[pytest]
+log_cli = true
+log_cli_level = INFO