// Forward pass of poly_norm (declaration only; the definition is not visible
// here). The function returns void and `out` is the only non-const tensor
// argument, so the result is presumably written into `out`.
// NOTE(review): confirm whether the caller must preallocate `out`, and with
// what shape/dtype relative to `input`.
//   out     - destination tensor (non-const reference).
//   input   - source tensor (read-only).
//   weights - read-only parameter tensor; presumably learned scales. The
//             expected number of elements is not visible here; verify in the
//             definition.
//   bias    - read-only parameter tensor; presumably a learned additive bias.
//   eps     - presumably a small stability constant used in the
//             normalization; verify in the definition.
void poly_norm(torch::Tensor &out, const torch::Tensor &input,
               const torch::Tensor &weights, const torch::Tensor &bias,
               double eps);
// Backward pass of poly_norm (declaration only; the definition is not visible
// here). The function returns void, and the three non-const tensors are
// presumably filled with gradients.
//   input_grad  - destination for the gradient w.r.t. `input`.
//   weight_grad - destination for the gradient w.r.t. the weight tensor.
//   bias_grad   - destination for the gradient w.r.t. the forward pass's bias.
//   output_grad - upstream gradient of the forward output (read-only).
//   input       - the forward-pass input (read-only).
//   weight      - the forward-pass weight tensor (named `weights` in the
//                 poly_norm declaration); presumably the same tensor.
//   eps         - presumably the same stability constant passed to the
//                 forward pass.
// Unlike poly_norm, this signature takes no bias tensor, so any bias gradient
// must be computable without it. TODO confirm against the definition.
void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
                        torch::Tensor &bias_grad,
                        const torch::Tensor &output_grad,
                        const torch::Tensor &input, const torch::Tensor &weight,
                        double eps);
// Forward pass of rms_norm (declaration only; the definition is not visible
// here). Presumably RMS normalization scaled by `weights`; confirm in the
// definition. The function returns void and `out` is the only non-const
// tensor, so the result is presumably written into `out`.
//   out     - destination tensor (non-const reference).
//   input   - source tensor (read-only).
//   weights - read-only parameter tensor; presumably per-feature scales.
//   eps     - presumably a small stability constant inside the RMS
//             denominator; verify in the definition.
void rms_norm(torch::Tensor &out, const torch::Tensor &input,
              const torch::Tensor &weights, double eps);
// Backward pass of rms_norm (declaration only; the definition is not visible
// here). The function returns void, and the two non-const tensors are
// presumably filled with gradients.
//   input_grad  - destination for the gradient w.r.t. `input`.
//   weight_grad - destination for the gradient w.r.t. the weight tensor.
//   output_grad - upstream gradient of the forward output (read-only).
//   input       - the forward-pass input (read-only).
//   weight      - the forward-pass weight tensor (named `weights` in the
//                 rms_norm declaration); presumably the same tensor.
//   eps         - presumably the same stability constant passed to the
//                 forward pass.
void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
                       const torch::Tensor &output_grad,
                       const torch::Tensor &input, const torch::Tensor &weight,
                       double eps);