activation_kernels.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {

// Activation and gating kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void act_and_mul_kernel(
  scalar_t* __restrict__ out,          // [..., d]
  const scalar_t* __restrict__ input,  // [..., 2, d]
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = APHRODITE_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = APHRODITE_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = ACT_FN(x) * y;
  }
}
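
// Reference for the gating pattern above: with the input viewed as
// [..., 2, d], each output element is
//   out[..., i] = ACT_FN(input[..., 0, i]) * input[..., 1, i]
// i.e. the PyTorch-style expression act(x[..., :d]) * x[..., d:] used by
// SwiGLU/GeGLU-style gated MLPs. One thread block handles one token and its
// threads stride over the hidden dimension d.
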
template<typename T>
__device__ __forceinline__ T silu_kernel(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'none' approximation.
  // Refer to:
  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
  const float f = (float) x;
  constexpr float ALPHA = M_SQRT1_2;
  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
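
// Worked form of the exact ('none') GELU above:
//   GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// so ALPHA = M_SQRT1_2 = 1/sqrt(2) ~= 0.70710678 scales the erf argument.
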
template<typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
  // Equivalent to PyTorch GELU with 'tanh' approximation.
  const float f = (float) x;
  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
  constexpr float KAPPA = 0.044715;
  float x_cube = f * f * f;
  float inner = BETA * (f + KAPPA * x_cube);
  return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
}
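
// Worked form of the tanh approximation above:
//   GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// where BETA = M_SQRT2 * M_2_SQRTPI * 0.5 = sqrt(2) * (2/sqrt(pi)) * 0.5
//            = sqrt(2/pi) ~= 0.7978845608 and KAPPA = 0.044715.
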
} // namespace aphrodite

// Launch activation and gating kernel.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
  int d = input.size(-1) / 2; \
  int64_t num_tokens = input.numel() / input.size(-1); \
  dim3 grid(num_tokens); \
  dim3 block(std::min(d, 1024)); \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
  APHRODITE_DISPATCH_FLOATING_TYPES( \
    input.scalar_type(), \
    "act_and_mul_kernel", \
    [&] { \
      aphrodite::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
        out.data_ptr<scalar_t>(), \
        input.data_ptr<scalar_t>(), \
        d); \
    });
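
// Notes on the launch configuration above: the input's last dimension holds
// the two gated halves back to back, so d = input.size(-1) / 2 and num_tokens
// collapses all leading dimensions. The grid launches one block per token
// with at most 1024 threads; each thread strides over d in the kernel loop.
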
void silu_and_mul(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel);
}

void gelu_and_mul(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_kernel);
}

void gelu_tanh_and_mul(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., 2 * d]
{
  LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_tanh_kernel);
}
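
// Expected host-side call pattern (a minimal sketch for illustration only;
// the allocation code below is not taken from this repository, but the
// shapes follow the comments above):
//
//   int64_t num_tokens = 32, d = 4096;
//   torch::Tensor input = torch::randn(
//       {num_tokens, 2 * d}, torch::dtype(torch::kHalf).device(torch::kCUDA));
//   torch::Tensor out = torch::empty({num_tokens, d}, input.options());
//   silu_and_mul(out, input);  // out = silu(input[..., :d]) * input[..., d:]
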
namespace aphrodite {

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,          // [..., d]
  const scalar_t* __restrict__ input,  // [..., d]
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = APHRODITE_LDG(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

} // namespace aphrodite

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
  int d = input.size(-1); \
  int64_t num_tokens = input.numel() / d; \
  dim3 grid(num_tokens); \
  dim3 block(std::min(d, 1024)); \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
  APHRODITE_DISPATCH_FLOATING_TYPES( \
    input.scalar_type(), \
    "activation_kernel", \
    [&] { \
      aphrodite::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
        out.data_ptr<scalar_t>(), \
        input.data_ptr<scalar_t>(), \
        d); \
    });
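
// This mirrors LAUNCH_ACTIVATION_GATE_KERNEL, except the kernel is purely
// element-wise: the last dimension is d itself (no gating halves), so out and
// input have the same shape.
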
namespace aphrodite {

template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float) (x * x * x);
  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float) x;
  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

} // namespace aphrodite
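
// Both kernels above are tanh-based GELU approximations built from the same
// constants as gelu_tanh_kernel (0.79788456f ~= sqrt(2/pi) and 0.044715f);
// they differ only in operation order and in where intermediates are cast
// back to T, matching the 'gelu_new' and 'gelu_fast' variants commonly seen
// in Transformer implementations.
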
void gelu_new(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_new_kernel);
}

void gelu_fast(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_fast_kernel);
}
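
// The element-wise entry points follow the same pattern but with matching
// shapes (again a sketch, for illustration only):
//
//   torch::Tensor input = torch::randn(
//       {num_tokens, d}, torch::dtype(torch::kHalf).device(torch::kCUDA));
//   torch::Tensor out = torch::empty_like(input);
//   gelu_fast(out, input);  // applies the fast GELU approximation per element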