#include "torch_binding.h" #include #include "registration.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // poly_norm ops.def("poly_norm(Tensor! out, Tensor input, Tensor weight, Tensor bias, " "float eps) -> ()"); ops.impl("poly_norm", torch::kCUDA, &poly_norm); ops.def("poly_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor! " "bias_grad, Tensor output_grad, Tensor input, Tensor weight, float " "eps) -> ()"); ops.impl("poly_norm_backward", torch::kCUDA, &poly_norm_backward); // rms_norm ops.def( "rms_norm(Tensor! out, Tensor input, Tensor weight, float eps) -> ()"); ops.impl("rms_norm", torch::kCUDA, &rms_norm); ops.def("rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, Tensor " "output_grad, Tensor input, Tensor weight, float eps) -> ()"); ops.impl("rms_norm_backward", torch::kCUDA, &rms_norm_backward); // fused_mul_poly_norm ops.def("fused_mul_poly_norm(Tensor! out, Tensor input, Tensor mul, Tensor " "weight, Tensor bias, " "float eps) -> ()"); ops.impl("fused_mul_poly_norm", torch::kCUDA, &fused_mul_poly_norm); ops.def("fused_mul_poly_norm_backward(Tensor! input_grad, Tensor! mul_grad, " "Tensor! weight_grad, Tensor! " "bias_grad, Tensor output_grad, Tensor input, Tensor mul, Tensor " "weight, Tensor " "bias, float eps) -> ()"); ops.impl("fused_mul_poly_norm_backward", torch::kCUDA, &fused_mul_poly_norm_backward); // fused_add_rms_norm ops.def( "fused_add_rms_norm(Tensor! out, Tensor! add_out, Tensor input, Tensor " "residual, Tensor " "weight, float eps) -> ()"); ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm); ops.def( "fused_add_rms_norm_backward(Tensor! input_grad, Tensor! weight_grad, " "Tensor " "output_grad, Tensor add_output_grad, Tensor input, Tensor weight, float " "eps) -> ()"); ops.impl("fused_add_rms_norm_backward", torch::kCUDA, &fused_add_rms_norm_backward); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME)