rotary.cpp 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. /******************************************************************************
  2. * Copyright (c) 2023, Tri Dao.
  3. ******************************************************************************/
  4. #include <torch/extension.h>
  5. #include <c10/cuda/CUDAGuard.h>
  // Argument-validation helpers built on TORCH_CHECK.
  //
  // CHECK_DEVICE(x): fails unless tensor `x` reports a CUDA device. The error
  // text is the stringified argument followed by " must be on CUDA".
  //
  // CHECK_SHAPE(x, d0, d1, ...): fails unless `x.sizes()` equals the listed
  // dimensions. It is not used in the visible part of this file.
  //
  // The macro argument is parenthesized where it is expanded as an expression
  // so that passing something other than a plain identifier still binds
  // correctly to the member call. The stringified form (#x) is unaffected.
  #define CHECK_DEVICE(x) \
      TORCH_CHECK((x).device().type() == torch::kCUDA, #x " must be on CUDA")
  #define CHECK_SHAPE(x, ...) \
      TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), \
                  #x " must have shape (" #__VA_ARGS__ ")")
  8. void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
  9. const torch::Tensor cos, const torch::Tensor sin,
  10. torch::Tensor out1, torch::Tensor out2,
  11. const bool conj);
  12. void apply_rotary(const torch::Tensor x1, const torch::Tensor x2,
  13. const torch::Tensor cos, const torch::Tensor sin,
  14. torch::Tensor out1, torch::Tensor out2,
  15. const bool conj) {
  16. CHECK_DEVICE(x1); CHECK_DEVICE(x2);
  17. CHECK_DEVICE(cos); CHECK_DEVICE(sin);
  18. CHECK_DEVICE(out1); CHECK_DEVICE(out1);
  19. TORCH_CHECK(x1.dtype() == x2.dtype());
  20. TORCH_CHECK(cos.dtype() == sin.dtype());
  21. TORCH_CHECK(out1.dtype() == out2.dtype());
  22. TORCH_CHECK(x1.dtype() == cos.dtype());
  23. TORCH_CHECK(x1.dtype() == out1.dtype());
  24. TORCH_CHECK(x1.sizes() == x2.sizes());
  25. TORCH_CHECK(cos.sizes() == sin.sizes());
  26. TORCH_CHECK(out1.sizes() == out2.sizes());
  27. // Otherwise the kernel will be launched from cuda:0 device
  28. at::cuda::CUDAGuard device_guard{x1.device()};
  29. apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
  30. }
  31. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  32. m.def("apply_rotary", &apply_rotary, "Apply rotary embedding");
  33. }