activation_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "dispatch_utils.h"

namespace aphrodite {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x)
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}

template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,         // [..., d]
  const scalar_t* __restrict__ input, // [..., 2, d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

} // namespace aphrodite
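
// Reference sketch (not part of the original file; the function name is
// hypothetical). The fused kernel above computes, per token,
// out[i] = silu(x[i]) * y[i], where x is the first half and y the second
// half of the last dimension of the [..., 2 * d] input.
inline void silu_and_mul_cpu_reference(float* out, const float* input,
                                       int num_tokens, int d) {
  for (int token = 0; token < num_tokens; ++token) {
    for (int i = 0; i < d; ++i) {
      const float x = input[token * 2 * d + i];
      const float y = input[token * 2 * d + d + i];
      out[token * d + i] = (x / (1.0f + expf(-x))) * y;
    }
  }
}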

void silu_and_mul(
  torch::Tensor& out,   // [..., d]
  torch::Tensor& input) // [..., 2 * d]
{
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      aphrodite::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
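
// Verification sketch (illustration only; the helper name is hypothetical).
// It checks the fused kernel against an unfused composition of ATen ops:
// split the last dimension in two, apply SiLU to the first half, and
// multiply by the second half.
inline bool silu_and_mul_matches_reference(int num_tokens = 4, int d = 16) {
  auto input = torch::randn({num_tokens, 2 * d},
                            torch::dtype(torch::kFloat32).device(torch::kCUDA));
  auto out = torch::empty({num_tokens, d}, input.options());
  silu_and_mul(out, input);
  auto halves = input.chunk(2, /*dim=*/-1);
  auto ref = torch::silu(halves[0]) * halves[1];
  return torch::allclose(out, ref, /*rtol=*/1e-5, /*atol=*/1e-5);
}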

namespace aphrodite {

// Element-wise activation kernel template.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,         // [..., d]
  const scalar_t* __restrict__ input, // [..., d]
  const int d) {
  const int token_idx = blockIdx.x;
  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

} // namespace aphrodite

// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                       \
  int d = input.size(-1);                                                      \
  int num_tokens = input.numel() / d;                                          \
  dim3 grid(num_tokens);                                                       \
  dim3 block(std::min(d, 1024));                                               \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                \
  APHRODITE_DISPATCH_FLOATING_TYPES(                                           \
    input.scalar_type(),                                                       \
    "activation_kernel",                                                       \
    [&] {                                                                      \
      aphrodite::activation_kernel<scalar_t, KERNEL<scalar_t>>                 \
        <<<grid, block, 0, stream>>>(                                          \
          out.data_ptr<scalar_t>(),                                            \
          input.data_ptr<scalar_t>(),                                          \
          d);                                                                  \
    });
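
// The macro above is expanded inside the body of each host wrapper below; it
// relies on the wrapper's `out` and `input` parameters and declares its own
// locals (d, num_tokens, grid, block, stream).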

namespace aphrodite {

// Tanh approximation of GELU:
// gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) ~= 0.79788456.
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float) (x * x * x);
  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

// Same tanh approximation, algebraically rearranged so that x is factored out
// of the tanh argument: 0.79788456 * x * (1 + 0.044715 * x * x).
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float) x;
  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

} // namespace aphrodite

void gelu_new(
  torch::Tensor& out,   // [..., d]
  torch::Tensor& input) // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_new_kernel);
}
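
// Verification sketch (illustration only; the helper name is hypothetical, and
// it assumes a PyTorch build where at::gelu accepts an `approximate` argument).
// gelu_new should match PyTorch's tanh-approximate GELU.
inline bool gelu_new_matches_reference(int num_tokens = 4, int d = 16) {
  auto input = torch::randn({num_tokens, d},
                            torch::dtype(torch::kFloat32).device(torch::kCUDA));
  auto out = torch::empty_like(input);
  gelu_new(out, input);
  auto ref = torch::gelu(input, /*approximate=*/"tanh");
  return torch::allclose(out, ref, /*rtol=*/1e-5, /*atol=*/1e-5);
}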

void gelu_fast(
  torch::Tensor& out,   // [..., d]
  torch::Tensor& input) // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_fast_kernel);
}
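
// Binding sketch (illustration only; the exported op names and the existence
// of a separate bindings file are assumptions). The wrappers above would
// typically be exposed to Python from a dedicated bindings translation unit,
// so the module definition is shown here as a comment rather than compiled:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("silu_and_mul", &silu_and_mul, "Fused SiLU-and-multiply activation");
//     m.def("gelu_new",     &gelu_new,     "GELU (tanh approximation)");
//     m.def("gelu_fast",    &gelu_fast,    "Fast GELU (tanh approximation)");
//   }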