File size: 879 Bytes
44e9845
 
 
 
f517c97
 
 
 
 
 
 
 
f3b99fb
f517c97
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
#pragma once

#include <torch/torch.h>

/// @brief Declares the forward `poly_norm` operation on tensors.
///
/// Only the prototype lives in this header; the definition is provided
/// elsewhere. Nothing is returned, so any result has to be delivered through
/// the non-const reference parameter.
///
/// @param out     Mutable tensor reference; presumably the destination the
///                result is written into — confirm against the definition.
/// @param input   Read-only input tensor.
/// @param weights Read-only weight tensor (the matching parameter of
///                `poly_norm_backward` is spelled `weight`).
/// @param bias    Read-only bias tensor.
/// @param eps     Double-precision epsilon; presumably a numerical-stability
///                term used by the normalization — TODO confirm.
void poly_norm(torch::Tensor &out, const torch::Tensor &input,
               const torch::Tensor &weights, const torch::Tensor &bias,
               double eps);
/// @brief Declares the backward counterpart of `poly_norm`.
///
/// Only the prototype lives in this header; the definition is provided
/// elsewhere. Nothing is returned, so any results have to be delivered through
/// the three non-const reference parameters.
///
/// @param input_grad  Mutable tensor reference; presumably receives the
///                    gradient with respect to `input` — confirm against the
///                    definition.
/// @param weight_grad Mutable tensor reference; presumably receives the
///                    gradient with respect to `weight` — confirm.
/// @param bias_grad   Mutable tensor reference; presumably receives the
///                    gradient with respect to the bias — confirm.
/// @param output_grad Read-only tensor; presumably the incoming gradient of
///                    the forward output — confirm.
/// @param input       Read-only input tensor.
/// @param weight      Read-only weight tensor (spelled `weights` in the
///                    `poly_norm` declaration).
/// @param eps         Double-precision epsilon; presumably must match the
///                    value used in the forward call — TODO confirm.
void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
                        torch::Tensor &bias_grad,
                        const torch::Tensor &output_grad,
                        const torch::Tensor &input, const torch::Tensor &weight,
                        double eps);

/// @brief Declares the forward `rms_norm` operation on tensors.
///
/// Only the prototype lives in this header; the definition is provided
/// elsewhere. Nothing is returned, so any result has to be delivered through
/// the non-const reference parameter.
///
/// @param out     Mutable tensor reference; presumably the destination the
///                result is written into — confirm against the definition.
/// @param input   Read-only input tensor.
/// @param weights Read-only weight tensor (the matching parameter of
///                `rms_norm_backward` is spelled `weight`).
/// @param eps     Double-precision epsilon; presumably a numerical-stability
///                term used by the normalization — TODO confirm.
void rms_norm(torch::Tensor &out, const torch::Tensor &input,
              const torch::Tensor &weights, double eps);
/// @brief Declares the backward counterpart of `rms_norm`.
///
/// Only the prototype lives in this header; the definition is provided
/// elsewhere. Nothing is returned, so any results have to be delivered through
/// the two non-const reference parameters.
///
/// @param input_grad  Mutable tensor reference; presumably receives the
///                    gradient with respect to `input` — confirm against the
///                    definition.
/// @param weight_grad Mutable tensor reference; presumably receives the
///                    gradient with respect to `weight` — confirm.
/// @param output_grad Read-only tensor; presumably the incoming gradient of
///                    the forward output — confirm.
/// @param input       Read-only input tensor.
/// @param weight      Read-only weight tensor (spelled `weights` in the
///                    `rms_norm` declaration).
/// @param eps         Double-precision epsilon; presumably must match the
///                    value used in the forward call — TODO confirm.
void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
                       const torch::Tensor &output_grad,
                       const torch::Tensor &input, const torch::Tensor &weight,
                       double eps);