layernorm.cpp 292 B

1234567891011121314
  1. #include <torch/extension.h>
  2. void rms_norm(
  3. torch::Tensor& out,
  4. torch::Tensor& input,
  5. torch::Tensor& weight,
  6. float epsilon);
  7. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  8. m.def(
  9. "rms_norm",
  10. &rms_norm,
  11. "Apply Root Mean Square (RMS) Normalization to the input tensor.");
  12. }