// Register the `rmsnorm_forward` operator schema and its CUDA implementation
// with the PyTorch dispatcher, under the namespace named by TORCH_EXTENSION_NAME.
// NOTE(review): TORCH_LIBRARY_EXPAND is assumed to be a thin wrapper that
// macro-expands TORCH_EXTENSION_NAME before forwarding to TORCH_LIBRARY —
// confirm against the project's registration header.
// A namespace may be claimed by only one TORCH_LIBRARY block (additional
// registrations must use TORCH_LIBRARY_FRAGMENT), and the module entry point
// may be defined only once, so this registration must appear exactly once.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // Schema: `Tensor! out` marks `out` as mutated in place by the op, and
  // `-> ()` means nothing is returned, so the result is delivered via `out`.
  // NOTE(review): `gamma` is presumably the RMSNorm scale/weight — confirm.
  ops.def("rmsnorm_forward(Tensor! out, Tensor input, Tensor gamma) -> ()");
  // Bind the implementation to the CUDA dispatch key only; no other backend
  // is registered in this block.
  // NOTE(review): assumes rmsnorm_forward is declared elsewhere with a
  // signature matching the schema above (e.g. Tensor&, const Tensor&,
  // const Tensor&) — confirm.
  ops.impl("rmsnorm_forward", torch::kCUDA, &rmsnorm_forward);
}

// NOTE(review): presumably defines the Python module init entry point for
// this extension — confirm against the macro definition.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)