File size: 879 Bytes
44e9845 f517c97 f3b99fb f517c97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
#pragma once
#include <torch/torch.h>
/// Forward computation for poly_norm. Only declared here; the definition
/// lives in another translation unit.
///
/// @param out      Destination tensor, taken by non-const reference;
///                 presumably written with the result. NOTE(review): confirm
///                 whether the caller must pre-allocate it with the final
///                 shape/dtype.
/// @param input    Input tensor (read-only).
/// @param weights  Weight tensor (read-only). Named `weights` here but
///                 `weight` in poly_norm_backward. The name is kept as-is
///                 because the definition is not visible from this header.
///                 Expected shape is not stated here; TODO confirm.
/// @param bias     Bias tensor (read-only).
/// @param eps      Epsilon forwarded to the normalization, presumably a
///                 numerical-stability term. TODO confirm.
void poly_norm(torch::Tensor &out,
               const torch::Tensor &input,
               const torch::Tensor &weights,
               const torch::Tensor &bias,
               double eps);
/// Backward counterpart of poly_norm. Only declared here; the definition
/// lives in another translation unit.
///
/// Unlike poly_norm, this takes no bias tensor as input, yet it still
/// produces `bias_grad`.
///
/// @param input_grad   Out-param (non-const reference); presumably receives
///                     the gradient with respect to `input`. TODO confirm.
/// @param weight_grad  Out-param; presumably receives the gradient with
///                     respect to `weight`.
/// @param bias_grad    Out-param; presumably receives the gradient with
///                     respect to the forward bias.
/// @param output_grad  Read-only; presumably the incoming gradient of the
///                     forward output.
/// @param input        Forward input tensor (read-only).
/// @param weight       Forward weight tensor (read-only). It is called
///                     `weights` in the poly_norm declaration.
/// @param eps          Epsilon; presumably must equal the value used in the
///                     forward call. TODO confirm.
void poly_norm_backward(torch::Tensor &input_grad,
                        torch::Tensor &weight_grad,
                        torch::Tensor &bias_grad,
                        const torch::Tensor &output_grad,
                        const torch::Tensor &input,
                        const torch::Tensor &weight,
                        double eps);
/// Forward computation for rms_norm. Only declared here; the definition
/// lives in another translation unit.
///
/// @param out      Destination tensor, taken by non-const reference;
///                 presumably written with the result. NOTE(review): confirm
///                 whether the caller must pre-allocate it.
/// @param input    Input tensor (read-only).
/// @param weights  Weight tensor (read-only). Named `weights` here but
///                 `weight` in rms_norm_backward. The name is kept as-is
///                 because the definition is not visible from this header.
/// @param eps      Epsilon forwarded to the normalization, presumably a
///                 numerical-stability term. TODO confirm.
void rms_norm(torch::Tensor &out,
              const torch::Tensor &input,
              const torch::Tensor &weights,
              double eps);
/// Backward counterpart of rms_norm. Only declared here; the definition
/// lives in another translation unit.
///
/// @param input_grad   Out-param (non-const reference); presumably receives
///                     the gradient with respect to `input`. TODO confirm.
/// @param weight_grad  Out-param; presumably receives the gradient with
///                     respect to `weight`.
/// @param output_grad  Read-only; presumably the incoming gradient of the
///                     forward output.
/// @param input        Forward input tensor (read-only).
/// @param weight       Forward weight tensor (read-only). It is called
///                     `weights` in the rms_norm declaration.
/// @param eps          Epsilon; presumably must equal the value used in the
///                     forward call. TODO confirm.
void rms_norm_backward(torch::Tensor &input_grad,
                       torch::Tensor &weight_grad,
                       const torch::Tensor &output_grad,
                       const torch::Tensor &input,
                       const torch::Tensor &weight,
                       double eps);
|