#include <torch/extension.h>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cudnn.h>
#include <cmath>
#include <algorithm>
// Persistent matrix multiplication kernel
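// Persistent-kernel pattern: roughly one resident thread block per SM loops
// over output tiles (a grid-stride loop over tile indices) instead of one
// block being launched per tile; tiles are visited in a grouped order to
// improve L2 reuse of A and B.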
__global__ void matmul_kernel_persistent(
const float *a_ptr,
const float *b_ptr,
float *c_ptr,
const float *bias_ptr,
int M, int N, int K,
int stride_am, int stride_ak,
int stride_bk, int stride_bn,
int stride_cm, int stride_cn,
int BLOCK_SIZE_M, int BLOCK_SIZE_N, int BLOCK_SIZE_K,
int GROUP_SIZE_M, int NUM_SMS,
bool HAS_BIAS)
{
int start_pid = blockIdx.x;
int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
int k_tiles = (K + BLOCK_SIZE_K - 1) / BLOCK_SIZE_K;
int num_tiles = num_pid_m * num_pid_n;
int num_pid_in_group = GROUP_SIZE_M * num_pid_n;
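// Grouped ordering: within a group of GROUP_SIZE_M tile-rows, consecutive
// tile indices keep the same column of B tiles for group_size_m steps, so
// neighbouring blocks tend to hit the same B data in L2.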
for (int tile_id = start_pid; tile_id < num_tiles; tile_id += NUM_SMS)
{
int group_id = tile_id / num_pid_in_group;
int first_pid_m = group_id * GROUP_SIZE_M;
int group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M);
int pid_m = first_pid_m + (tile_id % group_size_m);
int pid_n = (tile_id % num_pid_in_group) / group_size_m;
int start_m = pid_m * BLOCK_SIZE_M;
int start_n = pid_n * BLOCK_SIZE_N;
// Static shared-memory tiles; the kernel assumes
// BLOCK_SIZE_M == BLOCK_SIZE_N == BLOCK_SIZE_K == 16, matching the 16x16
// thread block configured by the host launcher.
__shared__ float As[16][16];
__shared__ float Bs[16][16];
float accumulator = 0.0f;
int tx = threadIdx.x;
int ty = threadIdx.y;
// K-dimension loop. Every thread in the block takes part in the tile loads
// and the __syncthreads() barriers (a divergent barrier would be undefined
// behaviour); out-of-range threads are masked only at the final store.
for (int ki = 0; ki < k_tiles; ki++)
{
int k_start = ki * BLOCK_SIZE_K;
// Load tiles into shared memory
if (k_start + tx < K && start_m + ty < M)
{
As[ty][tx] = a_ptr[(start_m + ty) * stride_am + (k_start + tx) * stride_ak];
}
else
{
As[ty][tx] = 0.0f;
}
if (k_start + ty < K && start_n + tx < N)
{
Bs[ty][tx] = b_ptr[(k_start + ty) * stride_bk + (start_n + tx) * stride_bn];
}
else
{
Bs[ty][tx] = 0.0f;
}
__syncthreads();
// Compute partial dot product
for (int k = 0; k < min(BLOCK_SIZE_K, K - k_start); k++)
{
accumulator += As[ty][k] * Bs[k][tx];
}
__syncthreads();
}
// Thread (ty, tx) owns output element (start_m + ty, start_n + tx); store it
// only if it lies inside the output matrix.
if (start_m + ty < M && start_n + tx < N)
{
// Add bias if present
if (HAS_BIAS && bias_ptr != nullptr)
{
accumulator += bias_ptr[start_n + tx];
}
// Store result
c_ptr[(start_m + ty) * stride_cm + (start_n + tx) * stride_cn] = accumulator;
}
}
}
// Log softmax kernel
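// One thread block per row. Numerically stable formulation:
//   log_softmax(x)_i = x_i - max(x) - log(sum_j exp(x_j - max(x)))
// computed in three passes: a block-wide max reduction, a block-wide sum of
// exponentials, then the elementwise write.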
__global__ void log_softmax_kernel(
const float *input_ptr,
float *output_ptr,
int input_row_stride,
int output_row_stride,
int n_cols,
int BLOCK_SIZE)
{
int row_idx = blockIdx.x;
int tid = threadIdx.x;
// Find maximum value in the row for numerical stability
__shared__ float max_val;
__shared__ float sum_exp;
if (tid == 0)
{
max_val = -INFINITY;
sum_exp = 0.0f;
}
__syncthreads();
// Reduction to find max
float thread_max = -INFINITY;
for (int col = tid; col < n_cols; col += blockDim.x)
{
float val = input_ptr[row_idx * input_row_stride + col];
thread_max = fmaxf(thread_max, val);
}
// Block-wide tree reduction for the max (assumes blockDim.x is a power of
// two and at most 256, matching the shared buffer size)
__shared__ float sdata[256];
sdata[tid] = thread_max;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1)
{
if (tid < s)
{
sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
}
__syncthreads();
}
if (tid == 0)
{
max_val = sdata[0];
}
__syncthreads();
// Compute sum of exp(x - max_val)
float thread_sum = 0.0f;
for (int col = tid; col < n_cols; col += blockDim.x)
{
float val = input_ptr[row_idx * input_row_stride + col];
thread_sum += expf(val - max_val);
}
// Block-wide reduction for sum
sdata[tid] = thread_sum;
__syncthreads();
for (int s = blockDim.x / 2; s > 0; s >>= 1)
{
if (tid < s)
{
sdata[tid] += sdata[tid + s];
}
__syncthreads();
}
if (tid == 0)
{
sum_exp = sdata[0];
}
__syncthreads();
float log_sum_exp = logf(sum_exp);
// Compute final log_softmax values
for (int col = tid; col < n_cols; col += blockDim.x)
{
float val = input_ptr[row_idx * input_row_stride + col];
output_ptr[row_idx * output_row_stride + col] = val - max_val - log_sum_exp;
}
}
// Mean reduction kernel
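// The input is viewed as [M, N, K] and reduced over the middle (N) dimension;
// each thread produces one output element (m, k) by accumulating N strided
// reads.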
__global__ void mean_kernel(
const float *input_ptr,
float *output_ptr,
int input_stride0, int input_stride1, int input_stride2,
int output_stride0, int output_stride1,
int M, int N, int K,
int BLOCK_SIZE)
{
int pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= M * K)
return;
int m_idx = pid / K;
int k_idx = pid % K;
float acc = 0.0f;
for (int n = 0; n < N; n++)
{
int input_idx = m_idx * input_stride0 + n * input_stride1 + k_idx * input_stride2;
acc += input_ptr[input_idx];
}
float mean_val = acc / N;
int output_idx = m_idx * output_stride0 + k_idx * output_stride1;
output_ptr[output_idx] = mean_val;
}
// Host functions that launch the kernels
void matmul_persistent_cuda(
torch::Tensor const &a,
torch::Tensor const &b,
torch::Tensor &c,
torch::Tensor const &bias)
{
const int M = a.size(0);
const int K = a.size(1);
const int N = b.size(1);
// Get device properties
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, a.get_device());
const int NUM_SMS = prop.multiProcessorCount;
// Tile sizes: these must match the 16x16 static shared-memory tiles and the
// 16x16 thread block used by matmul_kernel_persistent
const int BLOCK_SIZE_M = 16;
const int BLOCK_SIZE_N = 16;
const int BLOCK_SIZE_K = 16;
const int GROUP_SIZE_M = 8;
// Grid configuration
const int num_pid_m = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M;
const int num_pid_n = (N + BLOCK_SIZE_N - 1) / BLOCK_SIZE_N;
const int grid_size = std::min(NUM_SMS, num_pid_m * num_pid_n);
dim3 block(16, 16);
dim3 grid_dim(grid_size);
matmul_kernel_persistent<<<grid_dim, block>>>(
a.data_ptr<float>(),
b.data_ptr<float>(),
c.data_ptr<float>(),
bias.defined() ? bias.data_ptr<float>() : nullptr,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
GROUP_SIZE_M, NUM_SMS,
bias.defined());
}
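// Usage sketch (illustrative, not exercised in this file): assumes float32
// CUDA tensors with a: [M, K], b: [K, N], bias: [N] or undefined, and c
// preallocated as [M, N], e.g.
//   auto c = torch::empty({a.size(0), b.size(1)}, a.options());
//   matmul_persistent_cuda(a, b, c, bias);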
void log_softmax_cuda(
torch::Tensor const &input,
torch::Tensor &output)
{
// Flatten to 2-D: log-softmax is taken over the last dimension. The output
// tensor is assumed to be contiguous so that this reshape aliases its storage
// instead of returning a copy.
auto input_2d = input.reshape({-1, input.size(-1)}).contiguous();
auto output_2d = output.reshape({-1, output.size(-1)});
const int n_rows = input_2d.size(0);
const int n_cols = input_2d.size(1);
const int BLOCK_SIZE = 256;
log_softmax_kernel<<<n_rows, BLOCK_SIZE>>>(
input_2d.data_ptr<float>(),
output_2d.data_ptr<float>(),
input_2d.stride(0),
output_2d.stride(0),
n_cols,
BLOCK_SIZE);
}
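// Usage sketch (illustrative): log-softmax over the last dimension of a
// float32 CUDA tensor; the output should be contiguous (see note above), e.g.
//   auto out = torch::empty_like(x);
//   log_softmax_cuda(x, out);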
void mean_dim_cuda(
torch::Tensor const &input,
torch::Tensor &output,
int dim)
{
auto shape = input.sizes().vec();
int M = 1;
for (int i = 0; i < dim; i++)
{
M *= shape[i];
}
int N = shape[dim];
int K = 1;
for (int i = dim + 1; i < (int)shape.size(); i++)
{
K *= shape[i];
}
// View input as [M, N, K] and output as [M, K]; the output is assumed to be
// contiguous so the reshape aliases its storage
auto input_3d = input.reshape({M, N, K});
auto output_2d = output.reshape({M, K});
const int BLOCK_SIZE = 256;
const int grid_size = (M * K + BLOCK_SIZE - 1) / BLOCK_SIZE;
mean_kernel<<<grid_size, BLOCK_SIZE>>>(
input_3d.data_ptr<float>(),
output_2d.data_ptr<float>(),
input_3d.stride(0), input_3d.stride(1), input_3d.stride(2),
output_2d.stride(0), output_2d.stride(1),
M, N, K,
BLOCK_SIZE);
}
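// Binding sketch (illustrative; the exposed names are assumptions and the
// actual bindings may live in the remainder of this file):
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       m.def("matmul_persistent", &matmul_persistent_cuda, "Persistent tiled matmul (CUDA)");
//       m.def("log_softmax", &log_softmax_cuda, "Row-wise log-softmax (CUDA)");
//       m.def("mean_dim", &mean_dim_cuda, "Mean over one dimension (CUDA)");
//   }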