#include <torch/extension.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cmath>
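
// CUDA kernels and host launchers for a torch extension: a persistent tiled
// matmul with optional fused bias, a numerically stable row-wise log-softmax,
// and a mean reduction over a single dimension. All paths assume float32.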

// Compile-time tile edge: the thread block and the shared-memory tiles are
// TILE_DIM x TILE_DIM, so the BLOCK_SIZE_M/N/K arguments below must all
// equal TILE_DIM.
constexpr int TILE_DIM = 16;

// Persistent matrix multiplication kernel with optional fused bias add.
// A fixed pool of blocks (at most one per SM) walks over all output tiles,
// visiting them in grouped order (GROUP_SIZE_M tile-rows at a time) to
// improve L2 reuse, in the style of Triton's persistent-matmul tutorial.
__global__ void matmul_kernel_persistent(
    const float *a_ptr,
    const float *b_ptr,
    float *c_ptr,
    const float *bias_ptr,
    int M, int N, int K,
    int stride_am, int stride_ak,
    int stride_bk, int stride_bn,
    int stride_cm, int stride_cn,
    int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K,
    int GROUP_SIZE_M, int NUM_SMS,
    bool HAS_BIAS)
{
    int start_pid = blockIdx.x;
    int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
    int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
    int k_tiles = (K + BLOCK_SIZE_K - 1) / BLOCK_SIZE_K;
    int num_tiles = num_pid_m * num_pid_n;

    int num_pid_in_group = GROUP_SIZE_M * num_pid_n;

    // Shared memory for the current A and B tiles
    __shared__ float As[TILE_DIM][TILE_DIM];
    __shared__ float Bs[TILE_DIM][TILE_DIM];

    int tx = threadIdx.x;
    int ty = threadIdx.y;

    for (int tile_id = start_pid; tile_id < num_tiles; tile_id += NUM_SMS)
    {
        // Map the linear tile_id to a (pid_m, pid_n) tile in grouped order
        int group_id = tile_id / num_pid_in_group;
        int first_pid_m = group_id * GROUP_SIZE_M;
        int group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M);
        int pid_m = first_pid_m + (tile_id % group_size_m);
        int pid_n = (tile_id % num_pid_in_group) / group_size_m;

        int start_m = pid_m * BLOCK_SIZE_M;
        int start_n = pid_n * BLOCK_SIZE_N;

        float accumulator = 0.0f;

        // K-dimension loop. Every thread participates in the loads and the
        // barriers; only the final store is bounds-guarded, so no thread can
        // skip a __syncthreads() (a divergent barrier is undefined behavior).
        for (int ki = 0; ki < k_tiles; ki++)
        {
            int k_start = ki * BLOCK_SIZE_K;

            // Load tiles into shared memory, zero-padding out-of-range elements
            if (k_start + tx < K && start_m + ty < M)
            {
                As[ty][tx] = a_ptr[(start_m + ty) * stride_am + (k_start + tx) * stride_ak];
            }
            else
            {
                As[ty][tx] = 0.0f;
            }

            if (k_start + ty < K && start_n + tx < N)
            {
                Bs[ty][tx] = b_ptr[(k_start + ty) * stride_bk + (start_n + tx) * stride_bn];
            }
            else
            {
                Bs[ty][tx] = 0.0f;
            }

            __syncthreads();

            // Partial dot product over this K tile (the zero padding makes
            // iterating the full tile safe even for a ragged final tile)
            for (int k = 0; k < BLOCK_SIZE_K; k++)
            {
                accumulator += As[ty][k] * Bs[k][tx];
            }

            __syncthreads();
        }

        // Write the output element this thread owns, if it is in range
        if (start_m + ty < M && start_n + tx < N)
        {
            // Add bias if present
            if (HAS_BIAS && bias_ptr != nullptr)
            {
                accumulator += bias_ptr[start_n + tx];
            }
            c_ptr[(start_m + ty) * stride_cm + (start_n + tx) * stride_cn] = accumulator;
        }
    }
}
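
// Worked example of the grouped ordering above: with num_pid_m = 4,
// num_pid_n = 3 and GROUP_SIZE_M = 2, num_pid_in_group = 6 and tile_ids
// 0..5 map to (pid_m, pid_n) = (0,0), (1,0), (0,1), (1,1), (0,2), (1,2).
// A band of GROUP_SIZE_M tile-rows sweeps across all N tiles before the
// next band starts, so the A tiles for those rows stay hot in L2.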

// Log softmax kernel: one thread block per row
__global__ void log_softmax_kernel(
    const float *input_ptr,
    float *output_ptr,
    int input_row_stride,
    int output_row_stride,
    int n_cols,
    int BLOCK_SIZE)
{
    int row_idx = blockIdx.x;
    int tid = threadIdx.x;

    // Shared scratch: reduction buffer plus the broadcast max and sum.
    // blockDim.x must be a power of two and at most 256.
    __shared__ float sdata[256];
    __shared__ float max_val;
    __shared__ float sum_exp;

    // Each thread scans a strided slice of the row for its local maximum;
    // subtracting the row max later keeps expf from overflowing
    float thread_max = -INFINITY;
    for (int col = tid; col < n_cols; col += blockDim.x)
    {
        float val = input_ptr[row_idx * input_row_stride + col];
        thread_max = fmaxf(thread_max, val);
    }

    // Block-wide tree reduction for the max
    sdata[tid] = thread_max;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }

    if (tid == 0)
    {
        max_val = sdata[0];
    }
    __syncthreads();

    // Compute sum of exp(x - max_val)
    float thread_sum = 0.0f;
    for (int col = tid; col < n_cols; col += blockDim.x)
    {
        float val = input_ptr[row_idx * input_row_stride + col];
        thread_sum += expf(val - max_val);
    }

    // Block-wide reduction for sum
    sdata[tid] = thread_sum;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
        {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    if (tid == 0)
    {
        sum_exp = sdata[0];
    }
    __syncthreads();

    float log_sum_exp = logf(sum_exp);

    // Compute final log_softmax values
    for (int col = tid; col < n_cols; col += blockDim.x)
    {
        float val = input_ptr[row_idx * input_row_stride + col];
        output_ptr[row_idx * output_row_stride + col] = val - max_val - log_sum_exp;
    }
}
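
// Per row, the kernel computes log_softmax(x)_i = x_i - m - log(sum_j exp(x_j - m))
// with m = max_j x_j. This equals x_i - log(sum_j exp(x_j)) exactly, but keeps
// every expf argument <= 0, so nothing overflows.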

// Mean reduction kernel: averages over the middle (N) axis of an (M, N, K)
// view, one thread per output element
__global__ void mean_kernel(
    const float *input_ptr,
    float *output_ptr,
    int input_stride0, int input_stride1, int input_stride2,
    int output_stride0, int output_stride1,
    int M, int N, int K,
    int BLOCK_SIZE)
{
    int pid = blockIdx.x * blockDim.x + threadIdx.x;

    if (pid >= M * K)
        return;

    int m_idx = pid / K;
    int k_idx = pid % K;

    float acc = 0.0f;
    for (int n = 0; n < N; n++)
    {
        int input_idx = m_idx * input_stride0 + n * input_stride1 + k_idx * input_stride2;
        acc += input_ptr[input_idx];
    }

    float mean_val = acc / N;
    int output_idx = m_idx * output_stride0 + k_idx * output_stride1;
    output_ptr[output_idx] = mean_val;
}
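
// Example: reducing a (2, 5, 3) tensor over dim 1 gives M = 2, N = 5, K = 3;
// the launcher below spawns M * K = 6 threads, each summing N = 5 elements
// along the middle axis and dividing by N.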

// Host functions that launch the kernels
void matmul_persistent_cuda(
    torch::Tensor const &a,
    torch::Tensor const &b,
    torch::Tensor &c,
    torch::Tensor const &bias)
{
    const int M = a.size(0);
    const int K = a.size(1);
    const int N = b.size(1);

    // Get device properties for the current CUDA device
    int device = 0;
    cudaGetDevice(&device);
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    const int NUM_SMS = prop.multiProcessorCount;

    // Tile sizes: all must equal TILE_DIM, the kernel's compile-time tile
    // edge, since its shared-memory tiles and thread block are that size
    const int BLOCK_SIZE_M = TILE_DIM;
    const int BLOCK_SIZE_N = TILE_DIM;
    const int BLOCK_SIZE_K = TILE_DIM;
    const int GROUP_SIZE_M = 8;

    // Persistent grid: at most one block per SM, fewer if there are fewer tiles
    const int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
    const int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
    const int grid_size = std::min(NUM_SMS, num_pid_m * num_pid_n);

    dim3 block(TILE_DIM, TILE_DIM);
    dim3 grid_dim(grid_size);

    matmul_kernel_persistent<<<grid_dim, block>>>(
        a.data_ptr<float>(),
        b.data_ptr<float>(),
        c.data_ptr<float>(),
        bias.defined() ? bias.data_ptr<float>() : nullptr,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
        GROUP_SIZE_M, NUM_SMS,
        bias.defined());
}

void log_softmax_cuda(
    torch::Tensor const &input,
    torch::Tensor &output)
{
    // Flatten leading dimensions so each block handles one row of a 2-D
    // (rows x classes) problem; output is assumed contiguous so the reshape
    // aliases its storage
    auto input_2d = input.reshape({-1, input.size(-1)}).contiguous();
    auto output_2d = output.reshape({-1, output.size(-1)});

    const int n_rows = input_2d.size(0);
    const int n_cols = input_2d.size(1);

    const int BLOCK_SIZE = 256;

    log_softmax_kernel<<<n_rows, BLOCK_SIZE>>>(
        input_2d.data_ptr<float>(),
        output_2d.data_ptr<float>(),
        input_2d.stride(0),
        output_2d.stride(0),
        n_cols,
        BLOCK_SIZE);
}

void mean_dim_cuda(
    torch::Tensor const &input,
    torch::Tensor &output,
    int dim)
{
    auto shape = input.sizes().vec();

    // Collapse the dims before `dim` into M, keep the reduced dim as N,
    // and collapse the dims after it into K
    int M = 1;
    for (int i = 0; i < dim; i++)
    {
        M *= shape[i];
    }

    int N = shape[dim];

    int K = 1;
    for (int i = dim + 1; i < (int)shape.size(); i++)
    {
        K *= shape[i];
    }

    // View the input as (M, N, K); output is assumed contiguous so the
    // reshape aliases its storage
    auto input_3d = input.reshape({M, N, K});
    auto output_2d = output.reshape({M, K});

    const int BLOCK_SIZE = 256;
    const int grid_size = (M * K + BLOCK_SIZE - 1) / BLOCK_SIZE;

    mean_kernel<<<grid_size, BLOCK_SIZE>>>(
        input_3d.data_ptr<float>(),
        output_2d.data_ptr<float>(),
        input_3d.stride(0), input_3d.stride(1), input_3d.stride(2),
        output_2d.stride(0), output_2d.stride(1),
        M, N, K,
        BLOCK_SIZE);
}
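
// A minimal binding sketch so the launchers are callable from Python,
// assuming this file is built with torch.utils.cpp_extension (which defines
// TORCH_EXTENSION_NAME); the exposed names here are illustrative, not fixed
// by anything above. Callers are expected to pass float32 CUDA tensors with
// matching shapes and preallocated outputs.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("matmul_persistent", &matmul_persistent_cuda,
          "Persistent tiled matmul with optional fused bias (CUDA, float32)");
    m.def("log_softmax", &log_softmax_cuda,
          "Row-wise log-softmax over the last dimension (CUDA, float32)");
    m.def("mean_dim", &mean_dim_cuda,
          "Mean reduction over one dimension (CUDA, float32)");
}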