// activation_kernels.cu

#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {

template<typename T>
__device__ __forceinline__ T silu(const T& x) {
  // x * sigmoid(x); computed in fp32 so expf stays accurate for
  // half/bfloat16 inputs.
  return (T) (((float) x) / (1.0f + expf((float) -x)));
}
template<typename scalar_t>
__global__ void silu_and_mul_kernel(
  scalar_t* __restrict__ out,          // [..., d]
  const scalar_t* __restrict__ input,  // [..., 2, d]
  const int d) {
  // One block per token; threads stride across the hidden dimension d.
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    // x comes from the first half of the last dimension, y from the second.
    const scalar_t x = APHRODITE_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = APHRODITE_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

} // namespace aphrodite
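
// Indexing sketch for the kernel above: for token t and feature i it reads
//   x = input[t, 0, i]  (flat offset t * 2 * d + i)
//   y = input[t, 1, i]  (flat offset t * 2 * d + d + i)
// and writes out[t, i] = silu(x) * y, i.e. a fused SwiGLU-style
// gate-and-multiply over the two halves of the last dimension.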
void silu_and_mul(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., 2 * d]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  // One block per token, capped at 1024 threads (the per-block limit on
  // current NVIDIA GPUs); the strided loop inside the kernel covers d > 1024.
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "silu_and_mul_kernel",
    [&] {
      aphrodite::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        d);
    });
}
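
// Usage sketch from Python (the module handle "activation_ops" is a
// hypothetical name; the actual binding is registered elsewhere in the build):
//   x = torch.randn(num_tokens, 2 * d, device="cuda", dtype=torch.float16)
//   out = torch.empty(num_tokens, d, device="cuda", dtype=torch.float16)
//   activation_ops.silu_and_mul(out, x)  # out = silu(x[:, :d]) * x[:, d:]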
namespace aphrodite {

// Element-wise activation kernel template. ACT_FN is a compile-time template
// parameter, so each instantiation hard-codes (and can inline) its activation.
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
__global__ void activation_kernel(
  scalar_t* __restrict__ out,          // [..., d]
  const scalar_t* __restrict__ input,  // [..., d]
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = APHRODITE_LDG(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

} // namespace aphrodite
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                \
  int d = input.size(-1);                                               \
  int64_t num_tokens = input.numel() / d;                               \
  dim3 grid(num_tokens);                                                \
  dim3 block(std::min(d, 1024));                                        \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));     \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
  APHRODITE_DISPATCH_FLOATING_TYPES(                                    \
    input.scalar_type(),                                                \
    "activation_kernel",                                                \
    [&] {                                                               \
      aphrodite::activation_kernel<scalar_t, KERNEL<scalar_t>>          \
        <<<grid, block, 0, stream>>>(                                   \
          out.data_ptr<scalar_t>(),                                     \
          input.data_ptr<scalar_t>(),                                   \
          d);                                                           \
    });
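
// Illustration of the pattern the macro enables: any scalar __device__
// activation templated on T becomes a host-side op with one macro invocation.
// "squared_relu_kernel"/"squared_relu" are hypothetical names used only as an
// example here, not part of this file's original interface.
namespace aphrodite {
template<typename T>
__device__ __forceinline__ T squared_relu_kernel(const T& x) {
  // relu(x)^2, computed in fp32 like the other activations in this file.
  const float f = (float) x;
  const float r = f > 0.0f ? f : 0.0f;
  return (T) (r * r);
}
} // namespace aphrodite

void squared_relu(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::squared_relu_kernel);
}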
namespace aphrodite {

// GELU "new" (tanh) approximation:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))),
// with sqrt(2/pi) ~= 0.79788456.
template<typename T>
__device__ __forceinline__ T gelu_new_kernel(const T& x) {
  const float x3 = (float) (x * x * x);
  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

// The same tanh approximation with the cubic factored as
//   0.79788456 * x * (1 + 0.044715 * x^2).
template<typename T>
__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
  const float f = (float) x;
  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
  return ((T) 0.5) * x * (((T) 1.0) + t);
}

} // namespace aphrodite
void gelu_new(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_new_kernel);
}

void gelu_fast(
  torch::Tensor& out,    // [..., d]
  torch::Tensor& input)  // [..., d]
{
  LAUNCH_ACTIVATION_KERNEL(aphrodite::gelu_fast_kernel);
}
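
// Minimal binding sketch, assuming the ops are exposed as a standard PyTorch
// C++ extension. In the actual repo the registration may live in a separate
// bindings file, so treat this as illustrative only.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("silu_and_mul", &silu_and_mul, "Fused SiLU gate-and-multiply.");
  m.def("gelu_new", &gelu_new, "GELU (tanh approximation).");
  m.def("gelu_fast", &gelu_fast, "Fast GELU (tanh approximation).");
}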