#include <torch/all.h>

#include <thrust/execution_policy.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>

#include <algorithm> // std::max in the test harness below
#include <cmath>
#include <iomanip>   // Output formatting in the test harness below
#include <iostream>
// Numerical-stability term added to the mean of squares before the sqrt.
constexpr float EPS = 1e-5f;
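// RMSNorm, as implemented below (a restatement of the code, for reference):
//   rms(x_row) = sqrt(mean(x_row^2) + EPS)
//   out_row    = (x_row / rms(x_row)) * gamma
// Each (batch, seq) row of the (B, S, H) input is normalized independently
// over its H hidden features.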
// CPU reference implementation of RMSNorm.
// Expects x of shape (B, S, H) and gamma of shape (H), both float32.
torch::Tensor rmsnorm_forward_cpu(torch::Tensor x, torch::Tensor gamma) {
    int B = x.size(0), S = x.size(1), H = x.size(2);
    auto out = torch::empty_like(x);

    auto x_accessor = x.accessor<float, 3>();
    auto gamma_accessor = gamma.accessor<float, 1>();
    auto out_accessor = out.accessor<float, 3>();

    // Normalize each (b, s) row independently
    for (int b = 0; b < B; ++b) {
        for (int s = 0; s < S; ++s) {
            // Root mean square over the hidden dimension
            float sum_sq = 0.0f;
            for (int h = 0; h < H; ++h) {
                float val = x_accessor[b][s][h];
                sum_sq += val * val;
            }
            float rms = std::sqrt(sum_sq / H + EPS);

            // Normalize and apply the learned scale
            for (int h = 0; h < H; ++h) {
                out_accessor[b][s][h] = (x_accessor[b][s][h] / rms) * gamma_accessor[h];
            }
        }
    }
    return out;
}
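// GPU path: a Thrust functor that normalizes one (b, s) row per invocation.
// thrust::for_each maps each counting-iterator index to one row; the loops
// over the hidden dimension inside an invocation run sequentially per row.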
struct RmsnormFunctor {
    const float* x;
    const float* gamma;
    float* out;
    int hidden_dim;

    RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
        : x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}

    __device__
    void operator()(int row_idx) const {
        const float* row_x = x + row_idx * hidden_dim;
        float* row_out = out + row_idx * hidden_dim;

        float sum_sq = 0.0f;
        for (int i = 0; i < hidden_dim; ++i)
            sum_sq += row_x[i] * row_x[i];
        float rms = sqrtf(sum_sq / hidden_dim + EPS);

        for (int i = 0; i < hidden_dim; ++i)
            row_out[i] = (row_x[i] / rms) * gamma[i];
    }
};
// CUDA implementation: launches one Thrust work item per (batch, seq) row.
// Expects contiguous float32 CUDA tensors: x of shape (B, S, H), gamma of shape (H).
torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor gamma) {
    TORCH_CHECK(x.is_cuda() && gamma.is_cuda(), "x and gamma must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && gamma.is_contiguous(), "x and gamma must be contiguous");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32 && gamma.scalar_type() == torch::kFloat32,
                "x and gamma must be float32");

    int B = x.size(0), S = x.size(1), H = x.size(2);
    int rows = B * S;

    // Output tensor with the same shape, dtype, and device as the input
    auto out = torch::empty_like(x);

    const float* x_ptr = x.data_ptr<float>();
    const float* gamma_ptr = gamma.data_ptr<float>();
    float* out_ptr = out.data_ptr<float>();

    // One functor invocation per row index in [0, rows)
    thrust::counting_iterator<int> iter(0);
    thrust::for_each(
        thrust::device,
        iter, iter + rows,
        RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H)
    );
    return out;
}
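// Note: thrust::device runs on Thrust's default stream, which is not
// necessarily the stream PyTorch is currently recording work on. A minimal
// sketch of launching on PyTorch's current CUDA stream instead (assumption:
// the build links against ATen's CUDA headers; untested here):
//
//   #include <ATen/cuda/CUDAContext.h>
//   ...
//   thrust::for_each(
//       thrust::cuda::par.on(at::cuda::getCurrentCUDAStream()),
//       iter, iter + rows,
//       RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H));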
// Standalone test harness comparing the CPU and GPU implementations
// (kept commented out; the file is normally built as an extension).
//
// int main() {
//     int B = 2, S = 2, H = 4;
//
//     // Create tensors directly on CPU first
//     auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
//
//     // Initialize with CPU data
//     torch::Tensor x_cpu = torch::tensor({
//         {
//             {1.0f, 2.0f, 3.0f, 4.0f},
//             {5.0f, 6.0f, 7.0f, 8.0f}
//         },
//         {
//             {2.0f, 2.0f, 2.0f, 2.0f},
//             {9.0f, 10.0f, 11.0f, 12.0f}
//         }
//     }, options_cpu);
//     torch::Tensor gamma_cpu = torch::tensor({1.0f, 1.0f, 1.0f, 1.0f}, options_cpu);
//
//     // Run the CPU version
//     std::cout << "===== CPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cpu = rmsnorm_forward_cpu(x_cpu, gamma_cpu);
//     auto cpu_accessor = out_cpu.accessor<float, 3>();
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << cpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//
//     // Move tensors to CUDA for the GPU version
//     torch::Tensor x_cuda = x_cpu.to(torch::kCUDA);
//     torch::Tensor gamma_cuda = gamma_cpu.to(torch::kCUDA);
//
//     // Call the CUDA kernel wrapper
//     std::cout << "\n===== GPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cuda = rmsnorm_forward(x_cuda, gamma_cuda);
//
//     // Copy the result back to CPU and print
//     auto gpu_result_on_cpu = out_cuda.cpu();
//     auto gpu_accessor = gpu_result_on_cpu.accessor<float, 3>();
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << gpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//
//     // Check whether the results match
//     std::cout << "\n===== COMPARISON =====" << std::endl;
//     float max_diff = 0.0f;
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             for (int h = 0; h < H; ++h) {
//                 float diff = std::abs(cpu_accessor[b][s][h] - gpu_accessor[b][s][h]);
//                 max_diff = std::max(max_diff, diff);
//             }
//         }
//     }
//     std::cout << "Maximum difference between CPU and GPU results: "
//               << std::scientific << max_diff << std::endl;
//     std::cout << (max_diff < 1e-5f ? "PASSED: Results match!" : "FAILED: Results don't match!") << std::endl;
//     return 0;
// }
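
// A minimal binding sketch for exposing the ops as a PyTorch C++/CUDA
// extension (assumption: the file is compiled with torch's extension
// toolchain, which defines TORCH_EXTENSION_NAME at build time):
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//       m.def("rmsnorm_forward", &rmsnorm_forward, "RMSNorm forward (CUDA)");
//       m.def("rmsnorm_forward_cpu", &rmsnorm_forward_cpu, "RMSNorm forward (CPU)");
//   }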