// activation.cpp — pybind11 bindings exposing activation kernels to Python.
  1. #include <torch/extension.h>
// Forward declarations of the activation kernels registered below.
// The implementations live in another translation unit (presumably a
// CUDA kernel file — confirm against the build configuration).
// Naming suggests each kernel reads `input` and writes its result into
// `out` in place; exact shape/dtype requirements are not visible here.

// SwiGLU activation (per the binding docstring below:
// "Activation function used in SwiGLU.").
void silu_and_mul(
torch::Tensor& out,
torch::Tensor& input);

// "New" GELU variant (per the binding docstring:
// "GELU implementation used in GPT-2.").
void gelu_new(
torch::Tensor& out,
torch::Tensor& input);

// Approximate/fast GELU (per the binding docstring:
// "Approximate GELU implementation.").
void gelu_fast(
torch::Tensor& out,
torch::Tensor& input);
  11. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  12. m.def(
  13. "silu_and_mul",
  14. &silu_and_mul,
  15. "Activation function used in SwiGLU.");
  16. m.def(
  17. "gelu_new",
  18. &gelu_new,
  19. "GELU implementation used in GPT-2.");
  20. m.def(
  21. "gelu_fast",
  22. &gelu_fast,
  23. "Approximate GELU implementation.");
  24. }