#include <torch/torch.h>

#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

#include <algorithm>
#include <cmath>
#include <iomanip>
#include <iostream>

// Small constant added to the mean of squares for numerical stability.
const float EPS = 1e-5f;
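
// RMSNorm normalizes each hidden-dimension row by its root mean square:
//
//   rms(x) = sqrt( (1/H) * sum_{i=1}^{H} x_i^2 + EPS )
//   y_i    = (x_i / rms(x)) * gamma_i
//
// Unlike LayerNorm, no mean is subtracted and no bias is added; gamma is
// the only learned parameter.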
// CPU reference implementation of RMSNorm.
torch::Tensor rmsnorm_forward_cpu(torch::Tensor x, torch::Tensor gamma) {
    int B = x.size(0), S = x.size(1), H = x.size(2);
    auto out = torch::empty_like(x);
    auto x_accessor = x.accessor<float, 3>();
    auto gamma_accessor = gamma.accessor<float, 1>();
    auto out_accessor = out.accessor<float, 3>();
    // Process each (batch, seq) row independently.
    for (int b = 0; b < B; ++b) {
        for (int s = 0; s < S; ++s) {
            // Root mean square over the hidden dimension.
            float sum_sq = 0.0f;
            for (int h = 0; h < H; ++h) {
                float val = x_accessor[b][s][h];
                sum_sq += val * val;
            }
            float rms = std::sqrt(sum_sq / H + EPS);
            // Normalize and apply the learned scale.
            for (int h = 0; h < H; ++h) {
                out_accessor[b][s][h] = (x_accessor[b][s][h] / rms) * gamma_accessor[h];
            }
        }
    }
    return out;
}
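
// Worked example for the first test row below, x = {1, 2, 3, 4} with H = 4:
//   mean of squares = (1 + 4 + 9 + 16) / 4 = 7.5
//   rms             = sqrt(7.5 + 1e-5) ≈ 2.738615
//   output          ≈ {0.365148, 0.730296, 1.095444, 1.460593}  (gamma = 1)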
// Functor applied once per (batch * seq) row; executes on the device via Thrust.
struct RmsnormFunctor {
    const float* x;
    const float* gamma;
    float* out;
    int hidden_dim;
    RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
        : x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}
    __device__
    void operator()(int row_idx) {
        // Each row is a contiguous slice of hidden_dim floats.
        const float* row_x = x + row_idx * hidden_dim;
        float* row_out = out + row_idx * hidden_dim;
        float sum_sq = 0.0f;
        for (int i = 0; i < hidden_dim; ++i)
            sum_sq += row_x[i] * row_x[i];
        float rms = sqrtf(sum_sq / hidden_dim + EPS);
        for (int i = 0; i < hidden_dim; ++i)
            row_out[i] = (row_x[i] / rms) * gamma[i];
    }
};
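
// Design note: thrust::for_each maps roughly one GPU thread to each
// (batch, seq) row, and each thread loops over the hidden dimension
// serially. That keeps the code short, but for large H a block-per-row
// kernel with a parallel reduction would use memory bandwidth far better;
// thrust::device also runs on Thrust's default stream rather than
// PyTorch's current CUDA stream.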
torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor gamma) {
    TORCH_CHECK(x.is_cuda() && gamma.is_cuda(), "inputs must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && gamma.is_contiguous(), "inputs must be contiguous");
    int B = x.size(0), S = x.size(1), H = x.size(2);
    int rows = B * S;
    // Output tensor with the same shape, dtype, and device as the input.
    auto out = torch::empty_like(x);
    const float* x_ptr = x.data_ptr<float>();
    const float* gamma_ptr = gamma.data_ptr<float>();
    float* out_ptr = out.data_ptr<float>();
    // Launch one functor invocation per row over a counting iterator.
    thrust::counting_iterator<int> iter(0);
    thrust::for_each(
        thrust::device,
        iter, iter + rows,
        RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H)
    );
    return out;
}
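
// A minimal binding sketch, assuming this file is compiled as a PyTorch C++
// extension (e.g. via torch.utils.cpp_extension.load); in that setup
// <torch/extension.h> would be included instead of <torch/torch.h>, and
// TORCH_EXTENSION_NAME is supplied by the build. Uncomment to use.
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("rmsnorm_forward_cpu", &rmsnorm_forward_cpu, "RMSNorm forward (CPU)");
//     m.def("rmsnorm_forward", &rmsnorm_forward, "RMSNorm forward (CUDA, Thrust)");
// }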
// int main() {
//     int B = 2, S = 2, H = 4;
//     // Create tensors on the CPU first.
//     auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
//     torch::Tensor x_cpu = torch::tensor({
//         {
//             {1.0f, 2.0f, 3.0f, 4.0f},
//             {5.0f, 6.0f, 7.0f, 8.0f}
//         },
//         {
//             {2.0f, 2.0f, 2.0f, 2.0f},
//             {9.0f, 10.0f, 11.0f, 12.0f}
//         }
//     }, options_cpu);
//     torch::Tensor gamma_cpu = torch::tensor({1.0f, 1.0f, 1.0f, 1.0f}, options_cpu);
//     // Run the CPU reference.
//     std::cout << "===== CPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cpu = rmsnorm_forward_cpu(x_cpu, gamma_cpu);
//     auto cpu_accessor = out_cpu.accessor<float, 3>();
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << cpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//     // Move tensors to CUDA and run the Thrust version.
//     torch::Tensor x_cuda = x_cpu.to(torch::kCUDA);
//     torch::Tensor gamma_cuda = gamma_cpu.to(torch::kCUDA);
//     std::cout << "\n===== GPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cuda = rmsnorm_forward(x_cuda, gamma_cuda);
//     // Copy the result back to the CPU and print it.
//     auto gpu_result_on_cpu = out_cuda.cpu();
//     auto gpu_accessor = gpu_result_on_cpu.accessor<float, 3>();
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << gpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//     // Compare the two results elementwise.
//     std::cout << "\n===== COMPARISON =====" << std::endl;
//     float max_diff = 0.0f;
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             for (int h = 0; h < H; ++h) {
//                 float diff = std::abs(cpu_accessor[b][s][h] - gpu_accessor[b][s][h]);
//                 max_diff = std::max(max_diff, diff);
//             }
//         }
//     }
//     std::cout << "Maximum difference between CPU and GPU results: "
//               << std::scientific << max_diff << std::endl;
//     std::cout << (max_diff < 1e-5 ? "PASSED: Results match!" : "FAILED: Results don't match!") << std::endl;
//     return 0;
// }
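
// Build sketch (assumptions: the file is saved as rmsnorm.cu and a
// CUDA-enabled libtorch is unpacked at $LIBTORCH; exact library names and
// flags vary by libtorch release):
//   nvcc -std=c++17 rmsnorm.cu -o rmsnorm \
//     -I$LIBTORCH/include -I$LIBTORCH/include/torch/csrc/api/include \
//     -L$LIBTORCH/lib -ltorch -ltorch_cpu -ltorch_cuda -lc10 -lc10_cuda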