gagan3012 committed
Commit e6010fe · verified · 1 Parent(s): b40d98e

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,7 @@
+ ---
+ license: apache-2.0
+ tags:
+ - kernel
+ ---
+
+ # batch_invariant_kernel
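The README itself is minimal, so a short usage sketch helps orient readers. The snippet below is illustrative only: it assumes the kernel has been built and pushed so that it can be fetched with the Hugging Face `kernels` loader, and the repository id shown is a placeholder, not something confirmed by this commit.

```python
# Illustrative sketch; the repo id below is hypothetical.
import torch
from kernels import get_kernel

# Fetch the compiled kernel from the Hub (placeholder repo id).
batch_invariant = get_kernel("gagan3012/batch_invariant_kernel")

a = torch.randn(32, 64, device="cuda")
b = torch.randn(64, 16, device="cuda")

# Same API as torch-ext/batch_invariant/__init__.py
c = batch_invariant.matmul_persistent(a, b)
print(c.shape)  # torch.Size([32, 16])
```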
build.toml ADDED
@@ -0,0 +1,18 @@
+ [general]
+ name = "batch_invariant"
+ universal = false
+
+ # Defines the C++ files that bind to PyTorch
+ [torch]
+ src = [
+     "torch-ext/torch_binding.cpp",
+     "torch-ext/torch_binding.h"
+ ]
+
+ # Defines the CUDA kernels
+ [kernel.batch_invariant_matmul]
+ backend = "cuda"
+ depends = ["torch"]
+ src = [
+     "csrc/batch_invariant.cu",
+ ]
csrc/batch_invariant.cu ADDED
@@ -0,0 +1,315 @@
+ #include <torch/extension.h>
+ #include <cuda_runtime.h>
+ #include <cublas_v2.h>
+ #include <cudnn.h>
+ #include <algorithm>
+ #include <cmath>
+
+ // Persistent matrix multiplication kernel: a fixed number of blocks walks over
+ // all output tiles, so the reduction order for each output element depends only
+ // on the problem shape, not on the batch size.
+ __global__ void matmul_kernel_persistent(
+     const float *a_ptr,
+     const float *b_ptr,
+     float *c_ptr,
+     const float *bias_ptr,
+     int M, int N, int K,
+     int stride_am, int stride_ak,
+     int stride_bk, int stride_bn,
+     int stride_cm, int stride_cn,
+     int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K,
+     int GROUP_SIZE_M, int NUM_SMS,
+     bool HAS_BIAS)
+ {
+     int start_pid = blockIdx.x;
+     int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
+     int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
+     int k_tiles = (K + BLOCK_SIZE_K - 1) / BLOCK_SIZE_K;
+     int num_tiles = num_pid_m * num_pid_n;
+
+     int num_pid_in_group = GROUP_SIZE_M * num_pid_n;
+
+     for (int tile_id = start_pid; tile_id < num_tiles; tile_id += NUM_SMS)
+     {
+         int group_id = tile_id / num_pid_in_group;
+         int first_pid_m = group_id * GROUP_SIZE_M;
+         int group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M);
+         int pid_m = first_pid_m + (tile_id % group_size_m);
+         int pid_n = (tile_id % num_pid_in_group) / group_size_m;
+
+         int start_m = pid_m * BLOCK_SIZE_M;
+         int start_n = pid_n * BLOCK_SIZE_N;
+
+         // Shared memory tiles; sizes must match the 16x16 thread block
+         // configured by the host launcher below.
+         __shared__ float As[16][16];
+         __shared__ float Bs[16][16];
+
+         float accumulator = 0.0f;
+         int tx = threadIdx.x;
+         int ty = threadIdx.y;
+
+         // K-dimension loop. Every thread participates so that __syncthreads()
+         // is never skipped by a divergent branch; out-of-range loads are
+         // zero-filled instead.
+         for (int ki = 0; ki < k_tiles; ki++)
+         {
+             int k_start = ki * BLOCK_SIZE_K;
+
+             // Load tiles into shared memory
+             if (k_start + tx < K && start_m + ty < M)
+             {
+                 As[ty][tx] = a_ptr[(start_m + ty) * stride_am + (k_start + tx) * stride_ak];
+             }
+             else
+             {
+                 As[ty][tx] = 0.0f;
+             }
+
+             if (k_start + ty < K && start_n + tx < N)
+             {
+                 Bs[ty][tx] = b_ptr[(k_start + ty) * stride_bk + (start_n + tx) * stride_bn];
+             }
+             else
+             {
+                 Bs[ty][tx] = 0.0f;
+             }
+
+             __syncthreads();
+
+             // Compute partial dot product
+             for (int k = 0; k < min(BLOCK_SIZE_K, K - k_start); k++)
+             {
+                 accumulator += As[ty][k] * Bs[k][tx];
+             }
+
+             __syncthreads();
+         }
+
+         // This thread owns C[start_m + ty, start_n + tx]
+         if (start_m + ty < M && start_n + tx < N)
+         {
+             // Add bias if present
+             if (HAS_BIAS && bias_ptr != nullptr)
+             {
+                 accumulator += bias_ptr[start_n + tx];
+             }
+
+             // Store result
+             c_ptr[(start_m + ty) * stride_cm + (start_n + tx) * stride_cn] = accumulator;
+         }
+     }
+ }
+
+ // Log softmax kernel: one block per row
+ __global__ void log_softmax_kernel(
+     const float *input_ptr,
+     float *output_ptr,
+     int input_row_stride,
+     int output_row_stride,
+     int n_cols,
+     int BLOCK_SIZE)
+ {
+     int row_idx = blockIdx.x;
+     int tid = threadIdx.x;
+
+     // Row maximum and sum of exponentials, shared across the block
+     __shared__ float max_val;
+     __shared__ float sum_exp;
+
+     if (tid == 0)
+     {
+         max_val = -INFINITY;
+         sum_exp = 0.0f;
+     }
+     __syncthreads();
+
+     // Find the maximum value in the row for numerical stability
+     float thread_max = -INFINITY;
+     for (int col = tid; col < n_cols; col += blockDim.x)
+     {
+         float val = input_ptr[row_idx * input_row_stride + col];
+         thread_max = fmaxf(thread_max, val);
+     }
+
+     // Block-wide reduction for max
+     __shared__ float sdata[256];
+     sdata[tid] = thread_max;
+     __syncthreads();
+
+     for (int s = blockDim.x / 2; s > 0; s >>= 1)
+     {
+         if (tid < s)
+         {
+             sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
+         }
+         __syncthreads();
+     }
+
+     if (tid == 0)
+     {
+         max_val = sdata[0];
+     }
+     __syncthreads();
+
+     // Compute sum of exp(x - max_val)
+     float thread_sum = 0.0f;
+     for (int col = tid; col < n_cols; col += blockDim.x)
+     {
+         float val = input_ptr[row_idx * input_row_stride + col];
+         thread_sum += expf(val - max_val);
+     }
+
+     // Block-wide reduction for sum
+     sdata[tid] = thread_sum;
+     __syncthreads();
+
+     for (int s = blockDim.x / 2; s > 0; s >>= 1)
+     {
+         if (tid < s)
+         {
+             sdata[tid] += sdata[tid + s];
+         }
+         __syncthreads();
+     }
+
+     if (tid == 0)
+     {
+         sum_exp = sdata[0];
+     }
+     __syncthreads();
+
+     float log_sum_exp = logf(sum_exp);
+
+     // Compute final log_softmax values
+     for (int col = tid; col < n_cols; col += blockDim.x)
+     {
+         float val = input_ptr[row_idx * input_row_stride + col];
+         output_ptr[row_idx * output_row_stride + col] = val - max_val - log_sum_exp;
+     }
+ }
+
+ // Mean reduction kernel: one thread per (m, k) output element
+ __global__ void mean_kernel(
+     const float *input_ptr,
+     float *output_ptr,
+     int input_stride0, int input_stride1, int input_stride2,
+     int output_stride0, int output_stride1,
+     int M, int N, int K,
+     int BLOCK_SIZE)
+ {
+     int pid = blockIdx.x * blockDim.x + threadIdx.x;
+
+     if (pid >= M * K)
+         return;
+
+     int m_idx = pid / K;
+     int k_idx = pid % K;
+
+     float acc = 0.0f;
+     for (int n = 0; n < N; n++)
+     {
+         int input_idx = m_idx * input_stride0 + n * input_stride1 + k_idx * input_stride2;
+         acc += input_ptr[input_idx];
+     }
+
+     float mean_val = acc / N;
+     int output_idx = m_idx * output_stride0 + k_idx * output_stride1;
+     output_ptr[output_idx] = mean_val;
+ }
+
+ // Host functions that launch the kernels
+ void matmul_persistent_cuda(
+     torch::Tensor const &a,
+     torch::Tensor const &b,
+     torch::Tensor &c,
+     c10::optional<torch::Tensor> const &bias)
+ {
+     const int M = a.size(0);
+     const int K = a.size(1);
+     const int N = b.size(1);
+
+     // Query the device that actually holds the inputs
+     cudaDeviceProp prop;
+     cudaGetDeviceProperties(&prop, a.get_device());
+     const int NUM_SMS = prop.multiProcessorCount;
+
+     // Block sizes; these must match the 16x16 shared-memory tiles and the
+     // 16x16 thread block used inside matmul_kernel_persistent.
+     const int BLOCK_SIZE_M = 16;
+     const int BLOCK_SIZE_N = 16;
+     const int BLOCK_SIZE_K = 16;
+     const int GROUP_SIZE_M = 8;
+
+     // Grid configuration
+     const int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
+     const int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
+     const int grid_size = std::min(NUM_SMS, num_pid_m * num_pid_n);
+
+     dim3 block(16, 16);
+     dim3 grid_dim(grid_size);
+
+     // bias is declared as `Tensor?` in the op schema, so it arrives here as an
+     // optional tensor.
+     const bool has_bias = bias.has_value() && bias->defined();
+
+     matmul_kernel_persistent<<<grid_dim, block>>>(
+         a.data_ptr<float>(),
+         b.data_ptr<float>(),
+         c.data_ptr<float>(),
+         has_bias ? bias->data_ptr<float>() : nullptr,
+         M, N, K,
+         a.stride(0), a.stride(1),
+         b.stride(0), b.stride(1),
+         c.stride(0), c.stride(1),
+         BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
+         GROUP_SIZE_M, NUM_SMS,
+         has_bias);
+ }
+
+ void log_softmax_cuda(
+     torch::Tensor const &input,
+     torch::Tensor &output)
+ {
+     auto input_2d = input.reshape({-1, input.size(-1)}).contiguous();
+     auto output_2d = output.reshape({-1, output.size(-1)});
+
+     const int n_rows = input_2d.size(0);
+     const int n_cols = input_2d.size(1);
+
+     const int BLOCK_SIZE = 256;
+
+     log_softmax_kernel<<<n_rows, BLOCK_SIZE>>>(
+         input_2d.data_ptr<float>(),
+         output_2d.data_ptr<float>(),
+         input_2d.stride(0),
+         output_2d.stride(0),
+         n_cols,
+         BLOCK_SIZE);
+ }
+
+ void mean_dim_cuda(
+     torch::Tensor const &input,
+     torch::Tensor &output,
+     int dim)
+ {
+     auto shape = input.sizes().vec();
+
+     // Collapse the input to (M, N, K), where N is the reduced dimension
+     int M = 1;
+     for (int i = 0; i < dim; i++)
+     {
+         M *= shape[i];
+     }
+
+     int N = shape[dim];
+
+     int K = 1;
+     for (size_t i = dim + 1; i < shape.size(); i++)
+     {
+         K *= shape[i];
+     }
+
+     auto input_3d = input.reshape({M, N, K});
+     auto output_2d = output.reshape({M, K});
+
+     const int BLOCK_SIZE = 256;
+     const int grid_size = (M * K + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+     mean_kernel<<<grid_size, BLOCK_SIZE>>>(
+         input_3d.data_ptr<float>(),
+         output_2d.data_ptr<float>(),
+         input_3d.stride(0), input_3d.stride(1), input_3d.stride(2),
+         output_2d.stride(0), output_2d.stride(1),
+         M, N, K,
+         BLOCK_SIZE);
+ }
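The point of the persistent launch is that the work decomposition, and therefore the floating-point reduction order for each output element, is fixed by the problem shape rather than by how many rows happen to be in the batch. A quick way to exercise that property from Python is sketched below; it assumes the built extension is importable as `batch_invariant` (the package name used under `torch-ext/`).

```python
# Sketch of a batch-invariance check; the import path is an assumption.
import torch
from batch_invariant import matmul_persistent

torch.manual_seed(0)
a = torch.randn(64, 128, device="cuda")
b = torch.randn(128, 32, device="cuda")

full = matmul_persistent(a, b)       # whole batch at once
row = matmul_persistent(a[3:4], b)   # the same row in a batch of one

# With a batch-invariant kernel the two results should be bitwise identical,
# not merely close.
assert torch.equal(full[3:4], row)

# Sanity check against the reference matmul (tolerance, not bitwise).
assert torch.allclose(full, a @ b, atol=1e-4, rtol=1e-4)
```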
flake.nix ADDED
@@ -0,0 +1,17 @@
+ {
+   description = "Flake for batch_invariant kernel";
+
+   inputs = {
+     kernel-builder.url = "github:huggingface/kernel-builder";
+   };
+
+   outputs =
+     {
+       self,
+       kernel-builder,
+     }:
+     kernel-builder.lib.genFlakeOutputs {
+       path = ./.;
+       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+     };
+ }
torch-ext/batch_invariant/__init__.py ADDED
@@ -0,0 +1,120 @@
+ import torch
+ from ._ops import ops
+
+
+ def matmul_persistent(
+     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor = None
+ ) -> torch.Tensor:
+     """
+     Persistent matrix multiplication with optional bias.
+
+     Args:
+         a: Input tensor of shape (M, K)
+         b: Input tensor of shape (K, N)
+         bias: Optional bias tensor of shape (N,)
+
+     Returns:
+         Output tensor of shape (M, N)
+     """
+     assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+     assert a.dtype == b.dtype, "Incompatible dtypes"
+     assert bias is None or bias.dim() == 1, "Bias must be 1D"
+
+     M, K = a.shape
+     K, N = b.shape
+
+     c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+     ops.matmul_persistent(a, b, c, bias)
+
+     return c
+
+
+ def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
+     """
+     Compute log_softmax using custom CUDA kernel.
+
+     Args:
+         input: Input tensor
+         dim: Dimension along which to compute log_softmax (only -1 supported)
+
+     Returns:
+         Tensor with log_softmax applied
+     """
+     if dim != -1 and dim != input.ndim - 1:
+         raise ValueError(
+             "This implementation only supports log_softmax along the last dimension"
+         )
+
+     output = torch.empty_like(input)
+     ops.log_softmax(input, output)
+
+     return output
+
+
+ def mean_dim(
+     input: torch.Tensor, dim: int, keepdim: bool = False, dtype: torch.dtype = None
+ ) -> torch.Tensor:
+     """
+     Compute mean along a single dimension.
+
+     Args:
+         input: Input tensor
+         dim: Single dimension along which to compute mean
+         keepdim: Whether to keep the reduced dimension
+         dtype: Output dtype
+
+     Returns:
+         Tensor with mean values along specified dimension
+     """
+     assert input.is_cuda, "Input must be a CUDA tensor"
+     assert -input.ndim <= dim < input.ndim, f"Invalid dimension {dim}"
+
+     if dim < 0:
+         dim = dim + input.ndim
+
+     if dtype is None:
+         if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
+             dtype = torch.float32
+         else:
+             dtype = input.dtype
+
+     if input.dtype != dtype:
+         input = input.to(dtype)
+
+     shape = list(input.shape)
+
+     if keepdim:
+         output_shape = shape.copy()
+         output_shape[dim] = 1
+     else:
+         output_shape = shape[:dim] + shape[dim + 1 :]
+
+     output = torch.empty(output_shape, dtype=dtype, device=input.device)
+     ops.mean_dim(input, output, dim)
+
+     return output
+
+
+ # Batch-invariant wrappers matching the corresponding aten op signatures
+ # (used for mode switching)
+ def mm_batch_invariant(a, b):
+     return matmul_persistent(a, b)
+
+
+ def addmm_batch_invariant(bias, a, b):
+     return matmul_persistent(a, b, bias=bias)
+
+
+ def _log_softmax_batch_invariant(input, dim, _half_to_float):
+     assert not _half_to_float, "not implemented"
+     return log_softmax(input, dim=dim)
+
+
+ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype = None):
+     if len(dim) == 1:
+         return mean_dim(input, dim[0], keepdim=keepdim, dtype=dtype)
+     else:
+         # Multi-dimensional mean fallback
+         n_elems = 1
+         for d in dim:
+             n_elems *= input.shape[d]
+         return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
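The `*_batch_invariant` wrappers mirror the signatures of `aten::mm`, `aten::addmm`, `aten::_log_softmax`, and `aten::mean.dim`, which suggests they are meant to be swapped in for those ops. One way to wire that up, sketched here as an assumption rather than something this commit ships, is a `torch.library.Library` override for the CUDA dispatch key (again assuming the package imports as `batch_invariant`):

```python
# Hypothetical mode switch; this helper is not part of the committed package.
import torch
from batch_invariant import (
    mm_batch_invariant,
    addmm_batch_invariant,
    _log_softmax_batch_invariant,
    mean_batch_invariant,
)

_lib = None


def enable_batch_invariant_mode():
    """Route selected aten ops to the batch-invariant kernels on CUDA."""
    global _lib
    if _lib is not None:
        return
    _lib = torch.library.Library("aten", "IMPL")
    _lib.impl("mm", mm_batch_invariant, "CUDA")
    _lib.impl("addmm", addmm_batch_invariant, "CUDA")
    _lib.impl("_log_softmax", _log_softmax_batch_invariant, "CUDA")
    _lib.impl("mean.dim", mean_batch_invariant, "CUDA")
```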
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,15 @@
+ #include <torch/extension.h>
+
+ // registration.h is generated by kernel-builder and provides the
+ // TORCH_LIBRARY_EXPAND / REGISTER_EXTENSION macros used below.
+ #include "registration.h"
+ #include "torch_binding.h"
+
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops)
+ {
+     ops.def("matmul_persistent(Tensor a, Tensor b, Tensor! c, Tensor? bias) -> ()");
+     ops.def("log_softmax(Tensor input, Tensor! output) -> ()");
+     ops.def("mean_dim(Tensor input, Tensor! output, int dim) -> ()");
+
+     ops.impl("matmul_persistent", torch::kCUDA, &matmul_persistent_cuda);
+     ops.impl("log_softmax", torch::kCUDA, &log_softmax_cuda);
+     ops.impl("mean_dim", torch::kCUDA, &mean_dim_cuda);
+ }
+
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
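In these schemas, `Tensor!` marks an argument the op mutates in place (the preallocated output) and `Tensor?` marks an optional argument, so the raw ops are out-variant and return nothing. The Python wrappers above allocate the output and then call them roughly as sketched below (the `ops` handle comes from the generated `_ops` module; the top-level import path is an assumption):

```python
# Sketch of the out-variant calling convention used by the raw ops.
import torch
from batch_invariant._ops import ops  # generated binding module

a = torch.randn(8, 16, device="cuda")
b = torch.randn(16, 4, device="cuda")
c = torch.empty(8, 4, device="cuda")  # Tensor! argument: written in place

ops.matmul_persistent(a, b, c, None)  # Tensor? bias: None is allowed
```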
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,17 @@
+ #pragma once
+ #include <torch/extension.h>
+
+ // bias is declared as `Tensor?` in the op schema, so it is passed as an
+ // optional tensor.
+ void matmul_persistent_cuda(
+     torch::Tensor const &a,
+     torch::Tensor const &b,
+     torch::Tensor &c,
+     c10::optional<torch::Tensor> const &bias);
+
+ void log_softmax_cuda(
+     torch::Tensor const &input,
+     torch::Tensor &output);
+
+ void mean_dim_cuda(
+     torch::Tensor const &input,
+     torch::Tensor &output,
+     int dim);