activation_kernels.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {

// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., 2, d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = APHRODITE_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = APHRODITE_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = ACT_FN(x) * y;
  }
}

template <typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T)(((float)x) / (1.0f + expf((float)-x)));
}

template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
  const float f = (float)x;
  constexpr float ALPHA = M_SQRT1_2;
  return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}

template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with `tanh` approximation.
  const float f = (float)x;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float KAPPA = 0.044715;
  float x_cube = f * f * f;
  float inner = BETA * (f + KAPPA * x_cube);
  return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}

}  // namespace aphrodite

// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                             \
  int d = input.size(-1) / 2;                                             \
  int64_t num_tokens = input.numel() / input.size(-1);                    \
  dim3 grid(num_tokens);                                                  \
  dim3 block(std::min(d, 1024));                                          \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));       \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();           \
  APHRODITE_DISPATCH_FLOATING_TYPES(                                      \
      input.scalar_type(), "act_and_mul_kernel", [&] {                    \
        aphrodite::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>         \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                         input.data_ptr<scalar_t>(), d);  \
      });
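
// Example (descriptive note, not original to this file):
// LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel) launches one thread
// block per token with up to 1024 threads; each thread strides over the
// output width d, so d may exceed the block size.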

void silu_and_mul(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel);
}

void gelu_and_mul(torch::Tensor& out,    // [..., d]
                  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_kernel);
}

void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                       torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_tanh_kernel);
}
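
// Usage sketch (assumption: illustrative helper only, not part of the
// original file). The gated wrappers expect `input` of shape [..., 2 * d]
// and write act(input[..., :d]) * input[..., d:] into `out` of shape
// [..., d].
inline torch::Tensor silu_and_mul_example(torch::Tensor& input) {
  auto out_sizes = input.sizes().vec();
  out_sizes.back() /= 2;  // keep only the first half of the last dimension
  torch::Tensor out = torch::empty(out_sizes, input.options());
  silu_and_mul(out, input);
  return out;
}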

namespace aphrodite {

// Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
    scalar_t* __restrict__ out,          // [..., d]
    const scalar_t* __restrict__ input,  // [..., d]
    const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = APHRODITE_LDG(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

}  // namespace aphrodite

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                  \
  int d = input.size(-1);                                                 \
  int64_t num_tokens = input.numel() / d;                                 \
  dim3 grid(num_tokens);                                                  \
  dim3 block(std::min(d, 1024));                                          \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));       \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();           \
  APHRODITE_DISPATCH_FLOATING_TYPES(                                      \
      input.scalar_type(), "activation_kernel", [&] {                     \
        aphrodite::activation_kernel<scalar_t, KERNEL<scalar_t>>          \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                         input.data_ptr<scalar_t>(), d);  \
      });

namespace aphrodite {

template <typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float)(x * x * x);
  const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3))));
  return ((T)0.5) * x * (((T)1.0) + t);
}

template <typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float)x;
  const T t =
      (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x));
  return ((T)0.5) * x * (((T)1.0) + t);
}

template <typename T>
__device__ __forceinline__ T gelu_quick_kernel(const T& x) {
  // x * sigmoid(1.702 * x)
  return (T)(((float)x) / (1.0f + expf(-1.702f * (float)x)));
}

}  // namespace aphrodite

void gelu_new(torch::Tensor& out,    // [..., d]
              torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_new_kernel);
}

void gelu_fast(torch::Tensor& out,    // [..., d]
               torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_fast_kernel);
}

void gelu_quick(torch::Tensor& out,    // [..., d]
                torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_quick_kernel);
}
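
// Reference sketch (assumption: helper added for illustration, not part of
// the original file). Computes gelu_quick with plain ATen ops so the CUDA
// path above can be cross-checked, e.g.
//   torch::allclose(out, gelu_quick_ref(input), /*rtol=*/1e-3, /*atol=*/1e-3)
inline torch::Tensor gelu_quick_ref(const torch::Tensor& input) {
  // x * sigmoid(1.702 * x), matching aphrodite::gelu_quick_kernel.
  return input * torch::sigmoid(input * 1.702f);
}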