// activation_kernels.cu
  1. #include <ATen/cuda/CUDAContext.h>
  2. #include <torch/extension.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include <cmath>
  5. #include "cuda_compat.h"
  6. #include "dispatch_utils.h"
  7. namespace aphrodite {
  8. // Activation and gating kernel template.
  9. template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
  10. __global__ void act_and_mul_kernel(
  11. scalar_t* __restrict__ out, // [..., d]
  12. const scalar_t* __restrict__ input, // [..., 2, d]
  13. const int d) {
  14. const int64_t token_idx = blockIdx.x;
  15. for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
  16. const scalar_t x = APHRODITE_LDG(&input[token_idx * 2 * d + idx]);
  17. const scalar_t y = APHRODITE_LDG(&input[token_idx * 2 * d + d + idx]);
  18. out[token_idx * d + idx] = ACT_FN(x) * y;
  19. }
  20. }
  21. template<typename T>
  22. __device__ __forceinline__ T silu_kernel(const T& x) {
  23. // x * sigmoid(x)
  24. return (T) (((float) x) / (1.0f + expf((float) -x)));
  25. }
  26. template<typename T>
  27. __device__ __forceinline__ T gelu_kernel(const T& x) {
  28. // Equivalent to PyTorch GELU with 'none' approximation.
  29. // Refer to:
  30. // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L38
  31. const float f = (float) x;
  32. constexpr float ALPHA = M_SQRT1_2;
  33. return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
  34. }
  35. } // namespace aphrodite
// Launch the activation-and-gating kernel for `input` ([..., 2 * d]) into
// `out` ([..., d]), which must both be in scope at the expansion site.
// Grid: one block per token (num_tokens = numel / last-dim size);
// block: min(d, 1024) threads, striding over d inside the kernel.
// Runs on the current CUDA stream of input's device (guarded by
// OptionalCUDAGuard). KERNEL is a templated device function T (*)(const T&).
// NOTE(review): no cudaGetLastError() after the launch — errors surface at
// the next synchronizing call, matching the file's existing convention.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
APHRODITE_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"act_and_mul_kernel", \
[&] { \
aphrodite::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
// SiLU-gated multiply: out = silu(input[..., :d]) * input[..., d:],
// where d = input.size(-1) / 2. `out` must be preallocated to [..., d]
// on the same device and dtype as `input`.
void silu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::silu_kernel);
}
// GELU-gated multiply: out = gelu(input[..., :d]) * input[..., d:]
// (exact erf-based GELU), where d = input.size(-1) / 2. `out` must be
// preallocated to [..., d] on the same device and dtype as `input`.
void gelu_and_mul(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(aphrodite::gelu_kernel);
}
  65. namespace aphrodite {
  66. // Element-wise activation kernel template.
  67. template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
  68. __global__ void activation_kernel(
  69. scalar_t* __restrict__ out, // [..., d]
  70. const scalar_t* __restrict__ input, // [..., d]
  71. const int d) {
  72. const int64_t token_idx = blockIdx.x;
  73. for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
  74. const scalar_t x = APHRODITE_LDG(&input[token_idx * d + idx]);
  75. out[token_idx * d + idx] = ACT_FN(x);
  76. }
  77. }
  78. } // namespace aphrodite
// Launch the element-wise activation kernel for `input` ([..., d]) into
// `out` ([..., d]), which must both be in scope at the expansion site.
// Grid: one block per token; block: min(d, 1024) threads striding over d.
// Runs on the current CUDA stream of input's device (guarded by
// OptionalCUDAGuard). KERNEL is a templated device function T (*)(const T&).
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
APHRODITE_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"activation_kernel", \
[&] { \
aphrodite::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), \
d); \
});
  96. namespace aphrodite {
  97. template<typename T>
  98. __device__ __forceinline__ T gelu_new_kernel(const T& x) {
  99. const float x3 = (float) (x * x * x);
  100. const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
  101. return ((T) 0.5) * x * (((T) 1.0) + t);
  102. }
  103. template<typename T>
  104. __device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  105. const float f = (float) x;
  106. const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
  107. return ((T) 0.5) * x * (((T) 1.0) + t);
  108. }
  109. } // namespace aphrodite
// Element-wise tanh-approximated GELU ("gelu_new"): out = gelu_new(input).
// `out` must be preallocated with the same shape, device, and dtype as
// `input`.
void gelu_new(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_new_kernel);
}
// Element-wise fast tanh-approximated GELU ("gelu_fast"):
// out = gelu_fast(input). `out` must be preallocated with the same shape,
// device, and dtype as `input`.
void gelu_fast(
torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_fast_kernel);
}