File size: 375 Bytes
e6010fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#pragma once
#include <torch/extension.h>

void matmul_persistent_cuda(

    torch::Tensor const &a,

    torch::Tensor const &b,

    torch::Tensor &c,

    torch::Tensor const &bias);

void log_softmax_cuda(

    torch::Tensor const &input,

    torch::Tensor &output);

void mean_dim_cuda(

    torch::Tensor const &input,

    torch::Tensor &output,

    int dim);