#include <torch/torch.h>
#include <thrust/for_each.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/execution_policy.h>
#include <cuda_runtime.h>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <iomanip>  // For formatting output

// Epsilon added inside the square root for numerical stability.
const float EPS = 1e-5f;

// CPU reference implementation of RMSNorm:
//   y[b][s][h] = x[b][s][h] / sqrt(mean_h(x[b][s][h]^2) + EPS) * gamma[h]
torch::Tensor rmsnorm_forward_cpu(torch::Tensor x, torch::Tensor gamma) {
    int B = x.size(0), S = x.size(1), H = x.size(2);
    auto out = torch::empty_like(x);
    auto x_accessor = x.accessor<float, 3>();
    auto gamma_accessor = gamma.accessor<float, 1>();
    auto out_accessor = out.accessor<float, 3>();

    // Process each (batch, sequence) row independently.
    for (int b = 0; b < B; ++b) {
        for (int s = 0; s < S; ++s) {
            // Calculate the root mean square of the row.
            float sum_sq = 0.0f;
            for (int h = 0; h < H; ++h) {
                float val = x_accessor[b][s][h];
                sum_sq += val * val;
            }
            float rms = std::sqrt(sum_sq / H + EPS);

            // Normalize and scale.
            for (int h = 0; h < H; ++h) {
                out_accessor[b][s][h] = (x_accessor[b][s][h] / rms) * gamma_accessor[h];
            }
        }
    }
    return out;
}

// Device functor: one thread normalizes one (batch, sequence) row.
struct RmsnormFunctor {
    const float* x;
    const float* gamma;
    float* out;
    int hidden_dim;

    RmsnormFunctor(const float* x_, const float* gamma_, float* out_, int h_)
        : x(x_), gamma(gamma_), out(out_), hidden_dim(h_) {}

    __device__ void operator()(int row_idx) {
        const float* row_x = x + row_idx * hidden_dim;
        float* row_out = out + row_idx * hidden_dim;

        float sum_sq = 0.0f;
        for (int i = 0; i < hidden_dim; ++i)
            sum_sq += row_x[i] * row_x[i];
        float rms = sqrtf(sum_sq / hidden_dim + EPS);

        for (int i = 0; i < hidden_dim; ++i)
            row_out[i] = (row_x[i] / rms) * gamma[i];
    }
};

// GPU implementation: flattens (B, S, H) into B*S rows and dispatches one
// functor invocation per row via thrust::for_each.
torch::Tensor rmsnorm_forward(torch::Tensor x, torch::Tensor gamma) {
    TORCH_CHECK(x.is_cuda() && gamma.is_cuda(), "inputs must be CUDA tensors");
    TORCH_CHECK(x.is_contiguous() && gamma.is_contiguous(), "inputs must be contiguous");

    int B = x.size(0), S = x.size(1), H = x.size(2);
    int rows = B * S;

    // Create output tensor with the same shape as the input.
    auto out = torch::empty_like(x);

    const float* x_ptr = x.data_ptr<float>();
    const float* gamma_ptr = gamma.data_ptr<float>();
    float* out_ptr = out.data_ptr<float>();

    thrust::counting_iterator<int> iter(0);
    thrust::for_each(
        thrust::device,
        iter, iter + rows,
        RmsnormFunctor(x_ptr, gamma_ptr, out_ptr, H)
    );
    return out;
}
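// The functor above walks an entire row with a single thread, which keeps the
// Thrust dispatch simple but leaves most of the GPU idle when hidden_dim is
// large. Below is a minimal sketch (not part of the original code) of a
// block-per-row alternative: each block owns one row, threads accumulate
// strided partial sums of squares, and a shared-memory tree reduction
// combines them. It assumes a power-of-two block size (256 here) and
// contiguous float32 inputs.
__global__ void rmsnorm_kernel_blocked(const float* x, const float* gamma,
                                       float* out, int hidden_dim) {
    extern __shared__ float shm[];  // one float per thread, sized at launch
    const float* row_x = x + blockIdx.x * hidden_dim;
    float* row_out = out + blockIdx.x * hidden_dim;

    // Each thread accumulates a strided partial sum of squares.
    float partial = 0.0f;
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x)
        partial += row_x[i] * row_x[i];
    shm[threadIdx.x] = partial;
    __syncthreads();

    // Tree reduction in shared memory (blockDim.x must be a power of two).
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride)
            shm[threadIdx.x] += shm[threadIdx.x + stride];
        __syncthreads();
    }
    float rms = sqrtf(shm[0] / hidden_dim + EPS);

    // Normalize and scale, again striding over the row.
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x)
        row_out[i] = (row_x[i] / rms) * gamma[i];
}
// Hypothetical launch, one block per row with 256 threads:
//   rmsnorm_kernel_blocked<<<rows, 256, 256 * sizeof(float)>>>(
//       x_ptr, gamma_ptr, out_ptr, H);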
// Standalone test harness comparing the CPU and GPU paths. Kept commented
// out; enable it when building this file as an executable rather than as
// part of an extension.
//
// int main() {
//     int B = 2, S = 2, H = 4;
//
//     // Create tensors on the CPU first.
//     auto options_cpu = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
//     torch::Tensor x_cpu = torch::tensor({
//         {
//             {1.0f, 2.0f, 3.0f, 4.0f},
//             {5.0f, 6.0f, 7.0f, 8.0f}
//         },
//         {
//             {2.0f, 2.0f, 2.0f, 2.0f},
//             {9.0f, 10.0f, 11.0f, 12.0f}
//         }
//     }, options_cpu);
//     torch::Tensor gamma_cpu = torch::tensor({1.0f, 1.0f, 1.0f, 1.0f}, options_cpu);
//
//     // Run the CPU version.
//     std::cout << "===== CPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cpu = rmsnorm_forward_cpu(x_cpu, gamma_cpu);
//     auto cpu_accessor = out_cpu.accessor<float, 3>();
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << cpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//
//     // Move tensors to CUDA for the GPU version.
//     torch::Tensor x_cuda = x_cpu.to(torch::kCUDA);
//     torch::Tensor gamma_cuda = gamma_cpu.to(torch::kCUDA);
//
//     // Call the CUDA kernel wrapper.
//     std::cout << "\n===== GPU IMPLEMENTATION RESULTS =====" << std::endl;
//     torch::Tensor out_cuda = rmsnorm_forward(x_cuda, gamma_cuda);
//
//     // Copy the result back to the CPU and print it.
//     auto gpu_result_on_cpu = out_cuda.cpu();
//     auto gpu_accessor = gpu_result_on_cpu.accessor<float, 3>();
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             std::cout << "Row " << (b * S + s) << ": ";
//             for (int h = 0; h < H; ++h) {
//                 std::cout << std::fixed << std::setprecision(6) << gpu_accessor[b][s][h] << " ";
//             }
//             std::cout << "\n";
//         }
//     }
//
//     // Check whether the two results match.
//     std::cout << "\n===== COMPARISON =====" << std::endl;
//     float max_diff = 0.0f;
//     for (int b = 0; b < B; ++b) {
//         for (int s = 0; s < S; ++s) {
//             for (int h = 0; h < H; ++h) {
//                 float diff = std::abs(cpu_accessor[b][s][h] - gpu_accessor[b][s][h]);
//                 max_diff = std::max(max_diff, diff);
//             }
//         }
//     }
//     std::cout << "Maximum difference between CPU and GPU results: "
//               << std::scientific << max_diff << std::endl;
//     std::cout << (max_diff < 1e-5 ? "PASSED: Results match!" : "FAILED: Results don't match!") << std::endl;
//     return 0;
// }
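// One possible build invocation for the standalone test, assuming LibTorch is
// unpacked at /path/to/libtorch (adjust paths and the GPU architecture flag
// for your setup; this is an illustrative sketch, not a verified build line):
//
//   nvcc -std=c++17 rmsnorm.cu \
//        -I/path/to/libtorch/include \
//        -I/path/to/libtorch/include/torch/csrc/api/include \
//        -L/path/to/libtorch/lib \
//        -ltorch -ltorch_cpu -ltorch_cuda -lc10 -lc10_cuda \
//        -o rmsnorm_test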