iamwyldecat committed
Commit 704692b · 1 Parent(s): 883cc1c

fix(poly-norm): calc param grad explicitly
activation/activation_kernels.cu CHANGED
@@ -1,4 +1,5 @@
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/Functions.h>
 #include <torch/all.h>
 #include <c10/cuda/CUDAGuard.h>
 
@@ -78,8 +79,7 @@ __global__ void poly_norm_kernel(
 template <typename scalar_t, typename acc_t, int BLOCK_SIZE>
 __global__ void poly_norm_backward_kernel(
     scalar_t* __restrict__ input_grad,        // [..., d]
-    scalar_t* __restrict__ weight_grad,       // [3]
-    scalar_t* __restrict__ bias_grad,         // [1]
+    acc_t* __restrict__ temp_weight_grad,     // [..., 3]
     const scalar_t* __restrict__ output_grad, // [..., d]
     const scalar_t* __restrict__ input,       // [..., d]
     const scalar_t* __restrict__ weight,      // [3]
@@ -128,14 +128,17 @@ __global__ void poly_norm_backward_kernel(
   sum_dx_2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_2, d);
   sum_dx_3 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dx_3, d);
 
-  acc_t sq_mean_2 = powf(mean_2, -1.5);
-  acc_t sq_mean_4 = powf(mean_4, -1.5);
-  acc_t sq_mean_6 = powf(mean_6, -1.5);
+  acc_t _mean_2 = powf(mean_2, -1.5);
+  acc_t _mean_4 = powf(mean_4, -1.5);
+  acc_t _mean_6 = powf(mean_6, -1.5);
+
+  acc_t sq_mean_2 = sqrtf(mean_2);
+  acc_t sq_mean_4 = sqrtf(mean_4);
+  acc_t sq_mean_6 = sqrtf(mean_6);
 
   acc_t sum_dw0 = 0;
   acc_t sum_dw1 = 0;
   acc_t sum_dw2 = 0;
-  acc_t sum_db = 0;
 
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     acc_t dy = output_grad[token_idx * d + idx];
@@ -144,38 +147,30 @@ __global__ void poly_norm_backward_kernel(
     acc_t x_3 = x_2 * x_1;
 
     acc_t dx_3 =
-        sq_mean_6 * 3 * x_2 * (dy * mean_6 - x_3 * sum_dx_3 / d) * w0;
+        _mean_6 * 3 * x_2 * (dy * mean_6 - x_3 * sum_dx_3 / d) * w0;
     acc_t dx_2 =
-        sq_mean_4 * 2 * x_1 * (dy * mean_4 - x_2 * sum_dx_2 / d) * w1;
+        _mean_4 * 2 * x_1 * (dy * mean_4 - x_2 * sum_dx_2 / d) * w1;
     acc_t dx_1 =
-        sq_mean_2 * (dy * mean_2 - x_1 * sum_dx_1 / d) * w2;
+        _mean_2 * (dy * mean_2 - x_1 * sum_dx_1 / d) * w2;
 
     if (input_grad) {
       input_grad[token_idx * d + idx] = dx_1 + dx_2 + dx_3;
     }
 
-    sum_dw0 += dy * (x_3 / sqrt(mean_6));
-    sum_dw1 += dy * (x_2 / sqrt(mean_4));
-    sum_dw2 += dy * (x_1 / sqrt(mean_2));
-    sum_db += dy;
+    sum_dw0 += dy * (x_3 / sq_mean_6);
+    sum_dw1 += dy * (x_2 / sq_mean_4);
+    sum_dw2 += dy * (x_1 / sq_mean_2);
   }
 
-  if (weight_grad) {
+  if (temp_weight_grad) {
     sum_dw0 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw0, d);
     sum_dw1 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw1, d);
     sum_dw2 = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_dw2, d);
 
     if (threadIdx.x == 0) {
-      atomic_add(&weight_grad[0], sum_dw0);
-      atomic_add(&weight_grad[1], sum_dw1);
-      atomic_add(&weight_grad[2], sum_dw2);
-    }
-  }
-
-  if (bias_grad) {
-    sum_db = _block_reduce_sum<acc_t, BLOCK_SIZE>(shared, sum_db, d);
-    if (threadIdx.x == 0) {
-      atomic_add(&bias_grad[0], sum_db);
+      temp_weight_grad[token_idx * 3 + 0] = sum_dw0;
+      temp_weight_grad[token_idx * 3 + 1] = sum_dw1;
+      temp_weight_grad[token_idx * 3 + 2] = sum_dw2;
     }
   }
 }
@@ -236,14 +231,11 @@ void poly_norm_backward(
   dim3 grid(num_tokens);
   dim3 block(BLOCK_SIZE);
 
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  torch::Tensor temp_weight_grad =
+      torch::empty({num_tokens, 3},
+                   input.options().dtype(torch::kFloat));
 
-  if (weight_grad.defined()) {
-    cudaMemsetAsync(weight_grad.data_ptr(), 0, weight_grad.numel() * weight_grad.element_size(), stream);
-  }
-  if (bias_grad.defined()) {
-    cudaMemsetAsync(bias_grad.data_ptr(), 0, bias_grad.numel() * bias_grad.element_size(), stream);
-  }
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   MOTIF_DISPATCH_FLOATING_TYPES(
@@ -251,12 +243,15 @@
       motif::poly_norm_backward_kernel<scalar_t, float, BLOCK_SIZE>
           <<<grid, block, 0, stream>>>(
               input_grad.data_ptr<scalar_t>(),
-              weight_grad.data_ptr<scalar_t>(),
-              bias_grad.data_ptr<scalar_t>(),
+              temp_weight_grad.data_ptr<float>(),
              output_grad.data_ptr<scalar_t>(),
              input.data_ptr<scalar_t>(),
              weight.data_ptr<scalar_t>(),
              eps, d);
    }
  );
+
+  at::sum_out(bias_grad, output_grad);
+  at::sum_out(weight_grad, temp_weight_grad, {0});
+  bias_grad.resize_({1});
 }
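Net effect of the hunks above: instead of atomically accumulating into weight_grad and bias_grad inside the kernel, each block now writes its per-token partial sums into a float temp_weight_grad buffer of shape [num_tokens, 3], and the host reduces it with at::sum_out; bias_grad is computed directly as the sum of output_grad. Below is a minimal PyTorch sketch of the same computation for reference — the forward definition, eps placement, and function names are inferred assumptions rather than code from this repo, and x is assumed to have shape [num_tokens, d].

import torch


def poly_norm_ref(x, weight, bias, eps=1e-6):
    """Assumed PolyNorm forward, inferred from the backward kernel:
    y = w0 * norm(x^3) + w1 * norm(x^2) + w2 * norm(x) + b,
    with norm(u) = u / sqrt(mean(u^2) + eps) over the last dim."""
    def norm(u):
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)
    return weight[0] * norm(x ** 3) + weight[1] * norm(x ** 2) + weight[2] * norm(x) + bias


def poly_norm_param_grads_ref(x, output_grad, eps=1e-6):
    """Explicit parameter grads, mirroring temp_weight_grad + the host-side at::sum_out reduction."""
    def norm(u):
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)
    dy = output_grad
    # Per-token partial sums: what the kernel writes to temp_weight_grad[token, 0..2].
    per_token = torch.stack(
        [
            (dy * norm(x ** 3)).sum(dim=-1),  # sum_dw0
            (dy * norm(x ** 2)).sum(dim=-1),  # sum_dw1
            (dy * norm(x)).sum(dim=-1),       # sum_dw2
        ],
        dim=-1,
    )  # shape [num_tokens, 3]
    weight_grad = per_token.sum(dim=0)  # at::sum_out(weight_grad, temp_weight_grad, {0})
    bias_grad = dy.sum().reshape(1)     # at::sum_out(bias_grad, output_grad); then resize_({1})
    return weight_grad, bias_grad

Because the bias enters the forward pass purely additively, its gradient reduces to the plain sum of output_grad, which is why the host can compute it with at::sum_out on output_grad alone, without any per-token buffer.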
activation/atomic_utils.h CHANGED
@@ -27,19 +27,19 @@ __device__ inline void atomic_add<c10::BFloat16, float>(c10::BFloat16* _address,
 
   size_t offset = (size_t)address & 0x2;
   volatile uint16_t* address_as_short =
-      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
+      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
   volatile uint32_t* address_as_uint =
-      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
-  bool is_32bit_aligned = offset == 0;
+      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+  bool is_32bit_aligned = offset == 0;
 
-  uint32_t current = address_as_uint[0];
-  uint32_t expected;
+  uint32_t current = address_as_uint[0];
+  uint32_t expected;
 
   do {
     expected = current;
-    c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
-    c10::BFloat16 next_bf16 = current_bf16 + value;
-    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_bf16.x
+    c10::BFloat16 current_bf16(address_as_short[0], c10::BFloat16::from_bits());
+    c10::BFloat16 next_bf16 = current_bf16 + value;
+    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_bf16.x
                                      : (current & 0x0000ffff) | (next_bf16.x << 16);
     current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
   } while (current != expected);
@@ -51,19 +51,19 @@ __device__ inline void atomic_add<c10::Half, float>(c10::Half* _address, float v
 
   size_t offset = (size_t)address & 0x2;
   volatile uint16_t* address_as_short =
-      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
+      reinterpret_cast<volatile uint16_t*>(reinterpret_cast<volatile char*>(address));
   volatile uint32_t* address_as_uint =
-      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
-  bool is_32bit_aligned = offset == 0;
+      reinterpret_cast<volatile uint*>(reinterpret_cast<volatile char*>(address) - offset);
+  bool is_32bit_aligned = offset == 0;
 
-  uint32_t current = address_as_uint[0];
-  uint32_t expected;
+  uint32_t current = address_as_uint[0];
+  uint32_t expected;
 
   do {
     expected = current;
-    c10::Half current_half(address_as_short[0], c10::Half::from_bits());
-    c10::Half next_half = current_half + value;
-    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_half.x
+    c10::Half current_half(address_as_short[0], c10::Half::from_bits());
+    c10::Half next_half = current_half + value;
+    uint32_t next = is_32bit_aligned ? (current & 0xffff0000) | next_half.x
                                      : (current & 0x0000ffff) | (next_half.x << 16);
     current = atomicCAS(const_cast<uint32_t*>(address_as_uint), expected, next);
   } while (current != expected);
build/flake.lock ADDED
@@ -0,0 +1,168 @@
+{
+  "nodes": {
+    "flake-compat": {
+      "locked": {
+        "lastModified": 1747046372,
+        "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-compat_2": {
+      "locked": {
+        "lastModified": 1733328505,
+        "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
+        "type": "github"
+      },
+      "original": {
+        "owner": "edolstra",
+        "repo": "flake-compat",
+        "type": "github"
+      }
+    },
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "flake-utils_2": {
+      "inputs": {
+        "systems": "systems_2"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "hf-nix": {
+      "inputs": {
+        "flake-compat": "flake-compat_2",
+        "flake-utils": "flake-utils_2",
+        "nixpkgs": "nixpkgs"
+      },
+      "locked": {
+        "lastModified": 1747919133,
+        "narHash": "sha256-VvF1naQOvv7yulQ5/cDiaxkNxlh1Y84QMZnderv1szk=",
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "rev": "9c71e026d6c7c8588ef85a5f7c77f57d598e038c",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "hf-nix",
+        "type": "github"
+      }
+    },
+    "kernel-builder": {
+      "inputs": {
+        "flake-compat": "flake-compat",
+        "flake-utils": "flake-utils",
+        "hf-nix": "hf-nix",
+        "nixpkgs": [
+          "kernel-builder",
+          "hf-nix",
+          "nixpkgs"
+        ]
+      },
+      "locked": {
+        "lastModified": 1748620233,
+        "narHash": "sha256-VULm9HgGXvo3pyfsPy3SOhoqgkuqbGSaSemvzNUbdIU=",
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "rev": "da3340e5b3cbb6086600420f4814b033395788d1",
+        "type": "github"
+      },
+      "original": {
+        "owner": "huggingface",
+        "repo": "kernel-builder",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1747820358,
+        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
+        "owner": "danieldk",
+        "repo": "nixpkgs",
+        "rev": "d3c1681180717528068082103bf323147de6ab0b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "danieldk",
+        "ref": "cudatoolkit-12.9-kernel-builder",
+        "repo": "nixpkgs",
+        "type": "github"
+      }
+    },
+    "root": {
+      "inputs": {
+        "kernel-builder": "kernel-builder"
+      }
+    },
+    "systems": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    },
+    "systems_2": {
+      "locked": {
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
+        "owner": "nix-systems",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
+        "type": "github"
+      },
+      "original": {
+        "owner": "nix-systems",
+        "repo": "default",
+        "type": "github"
+      }
+    }
+  },
+  "root": "root",
+  "version": 7
+}
build/flake.nix ADDED
@@ -0,0 +1,11 @@
+{
+  description = "Flake for Torch kernel extension";
+  inputs = {
+    kernel-builder.url = "github:huggingface/kernel-builder";
+  };
+  outputs = { self, kernel-builder, }:
+    kernel-builder.lib.genFlakeOutputs {
+      path = ./.;
+      rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+    };
+}
build/torch26-cxx11-rocm62-x86_64-linux/activation/{_activation_32c2bde_dirty.abi3.so → _activation_883cc1c_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:151887149f1c434d7778bf213a758748e3fe15e3af5108c9a90fd679416c5ebe
-size 2425872
+oid sha256:a9d74188efdcb10158b338cf363749494f86e9712797722310f0a6ac5310efdd
+size 2401160
build/torch26-cxx11-rocm62-x86_64-linux/activation/_activation_f72121c_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b358f5be0dc4f1c1d7198ca4417c74cf9626f678b89772e178154acbaee1476a
-size 2460736
build/torch26-cxx11-rocm62-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_32c2bde_dirty
-ops = torch.ops._activation_32c2bde_dirty
+from . import _activation_883cc1c_dirty
+ops = torch.ops._activation_883cc1c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_32c2bde_dirty::{op_name}"
+    return f"_activation_883cc1c_dirty::{op_name}"
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_32c2bde_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c1f586de8406ba777d2d80bbcad8cc711032ef3971c1e963c7d31845c25b28c8
-size 2404376
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_552d415_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eecd2db703418250f22fbccdda97d8f45a15d3a8d34d1c6be1f0a1e3a7076990
-size 2447480
build/{torch26-cxx11-rocm62-x86_64-linux/activation/_activation_552d415_dirty.abi3.so → torch27-cxx11-rocm63-x86_64-linux/activation/_activation_883cc1c_dirty.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:98d5b88d4ef1ae793dc71c798416cc4d71b8d180a0a4627531ad7ab78116247d
-size 2464880
+oid sha256:719fc6521c0824b253cb11ea9e564ef7835e2102e5bc6399cfdb69203d6d5c26
+size 2395176
build/torch27-cxx11-rocm63-x86_64-linux/activation/_activation_f72121c_dirty.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d322fb12e4bd5eab4700783a6cfac4a8a9f9f21c7c61fd2ddb47253da8e182f1
-size 2447176
build/torch27-cxx11-rocm63-x86_64-linux/activation/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import _activation_32c2bde_dirty
-ops = torch.ops._activation_32c2bde_dirty
+from . import _activation_883cc1c_dirty
+ops = torch.ops._activation_883cc1c_dirty
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"_activation_32c2bde_dirty::{op_name}"
+    return f"_activation_883cc1c_dirty::{op_name}"
tests/pytest.ini ADDED
@@ -0,0 +1,3 @@
+[pytest]
+log_cli = true
+log_cli_level = INFO
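With log_cli enabled, pytest streams INFO-level log records to the console as the tests run. A hypothetical, self-contained sketch of the kind of test whose output this surfaces — the test name, shapes, and the locally defined _poly_norm reference are illustrative assumptions, not tests from this repo:

import logging

import torch

logger = logging.getLogger(__name__)


def _poly_norm(x, w, b, eps=1e-6):
    # Assumed PolyNorm forward: w0*norm(x^3) + w1*norm(x^2) + w2*norm(x) + b.
    norm = lambda u: u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)
    return w[0] * norm(x ** 3) + w[1] * norm(x ** 2) + w[2] * norm(x) + b


def test_bias_grad_is_sum_of_output_grad():
    torch.manual_seed(0)
    x = torch.randn(4, 8, dtype=torch.double)
    w = torch.randn(3, dtype=torch.double)
    b = torch.zeros(1, dtype=torch.double, requires_grad=True)
    dy = torch.randn_like(x)
    _poly_norm(x, w, b).backward(dy)
    # With log_cli = true, this record is printed live during the run.
    logger.info("bias grad %.6f vs output_grad.sum() %.6f", b.grad.item(), dy.sum().item())
    torch.testing.assert_close(b.grad, dy.sum().reshape(1))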