// layernorm_kernels.cu
#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction.cuh"
  6. namespace aphrodite {
  7. // TODO: Further optimize this kernel.
  8. template<typename scalar_t>
  9. __global__ void rms_norm_kernel(
  10. scalar_t* __restrict__ out, // [..., hidden_size]
  11. const scalar_t* __restrict__ input, // [..., hidden_size]
  12. const scalar_t* __restrict__ weight, // [hidden_size]
  13. const float epsilon,
  14. const int num_tokens,
  15. const int hidden_size) {
  16. __shared__ float s_variance;
  17. float variance = 0.0f;
  18. for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
  19. const float x = (float) input[blockIdx.x * hidden_size + idx];
  20. variance += x * x;
  21. }
  22. variance = blockReduceSum<float>(variance);
  23. if (threadIdx.x == 0) {
  24. s_variance = rsqrtf(variance / hidden_size + epsilon);
  25. }
  26. __syncthreads();
  27. for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
  28. float x = (float) input[blockIdx.x * hidden_size + idx];
  29. out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  30. }
  31. }
  32. // TODO: Further optimize this kernel.
  33. template<typename scalar_t>
  34. __global__ void fused_add_rms_norm_kernel(
  35. scalar_t* __restrict__ input, // [..., hidden_size]
  36. scalar_t* __restrict__ residual, // [..., hidden_size]
  37. const scalar_t* __restrict__ weight, // [hidden_size]
  38. const float epsilon,
  39. const int num_tokens,
  40. const int hidden_size) {
  41. __shared__ float s_variance;
  42. float variance = 0.0f;
  43. for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
  44. float x = (float) input[blockIdx.x * hidden_size + idx];
  45. x += (float) residual[blockIdx.x * hidden_size + idx];
  46. variance += x * x;
  47. residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
  48. }
  49. variance = blockReduceSum<float>(variance);
  50. if (threadIdx.x == 0) {
  51. s_variance = rsqrtf(variance / hidden_size + epsilon);
  52. }
  53. __syncthreads();
  54. for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
  55. float x = (float) residual[blockIdx.x * hidden_size + idx];
  56. input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  57. }
  58. }
  59. } // namespace aphrodite
  60. void rms_norm(
  61. torch::Tensor& out, // [..., hidden_size]
  62. torch::Tensor& input, // [..., hidden_size]
  63. torch::Tensor& weight, // [hidden_size]
  64. float epsilon) {
  65. int hidden_size = input.size(-1);
  66. int num_tokens = input.numel() / hidden_size;
  67. dim3 grid(num_tokens);
  68. dim3 block(std::min(hidden_size, 1024));
  69. const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  70. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  71. APHRODITE_DISPATCH_FLOATING_TYPES(
  72. input.scalar_type(),
  73. "rms_norm_kernel",
  74. [&] {
  75. aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
  76. out.data_ptr<scalar_t>(),
  77. input.data_ptr<scalar_t>(),
  78. weight.data_ptr<scalar_t>(),
  79. epsilon,
  80. num_tokens,
  81. hidden_size);
  82. });
  83. }
  84. void fused_add_rms_norm(
  85. torch::Tensor& input, // [..., hidden_size]
  86. torch::Tensor& residual, // [..., hidden_size]
  87. torch::Tensor& weight, // [hidden_size]
  88. float epsilon) {
  89. int hidden_size = input.size(-1);
  90. int num_tokens = input.numel() / hidden_size;
  91. dim3 grid(num_tokens);
  92. dim3 block(std::min(hidden_size, 1024));
  93. const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  94. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  95. APHRODITE_DISPATCH_FLOATING_TYPES(
  96. input.scalar_type(),
  97. "fused_add_rms_norm_kernel",
  98. [&] {
  99. aphrodite::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
  100. input.data_ptr<scalar_t>(),
  101. residual.data_ptr<scalar_t>(),
  102. weight.data_ptr<scalar_t>(),
  103. epsilon,
  104. num_tokens,
  105. hidden_size);
  106. });
  107. }