rotary_cuda.cu

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#include <torch/python.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// Applies a rotary embedding to the pair (x1, x2) element-wise:
//   out1 = x1 * cos - x2 * sin
//   out2 = x1 * sin + x2 * cos
// With conj = true the rotation is inverted (the sign of sin is flipped),
// which is what the backward pass needs.
void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
                       const torch::Tensor cos, const torch::Tensor sin,
                       torch::Tensor out1, torch::Tensor out2,
                       const bool conj) {
    // Build a TensorIterator over the two outputs and four inputs so the
    // element-wise kernel below gets broadcasting and shape checks for free.
    // cos/sin may have a different dtype than x1/x2, so the same-dtype check
    // and the promotion to a common dtype are disabled.
    auto iter = at::TensorIteratorConfig()
        .add_output(out1)
        .add_output(out2)
        .add_input(x1)
        .add_input(x2)
        .add_input(cos)
        .add_input(sin)
        .check_all_same_dtype(false)
        .promote_inputs_to_common_dtype(false)
        .build();
    if (!conj) {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
            at::native::gpu_kernel_multiple_outputs(
                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
                                     scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
                    // Compute in float even for half/bfloat16 inputs, then
                    // cast back to the input dtype.
                    scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin);
                    scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos);
                    return {out1, out2};
                });
        });
    } else {
        AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] {
            at::native::gpu_kernel_multiple_outputs(
                iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos,
                                     scalar_t sin) -> thrust::tuple<scalar_t, scalar_t> {
                    // Conjugate (inverse) rotation: the sign of sin is flipped.
                    scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin);
                    scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos);
                    return {out1, out2};
                });
        });
    }
}
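
The file itself only defines the kernel launcher. Below is a minimal sketch of how such a function might be exposed to Python when the file is compiled as a torch C++/CUDA extension; the module name, exported function name, and docstring are assumptions for illustration, not part of the original source.

// --- Hypothetical binding sketch (not part of the original file) ---
// Assumes a torch extension build, so TORCH_EXTENSION_NAME is defined and
// <torch/python.h> supplies the pybind11 machinery and Tensor type casters.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("apply_rotary", &apply_rotary_cuda,
          "out1 = x1*cos - x2*sin, out2 = x1*sin + x2*cos (conj flips the sign of sin)");
}

From Python, the bound function would then be called with pre-allocated outputs, e.g. `ext.apply_rotary(x1, x2, cos, sin, out1, out2, False)`, writing the rotated values into out1/out2.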