// layernorm_kernels.cu

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <algorithm>  // std::min

#include "dispatch_utils.h"
#include "reduction.cuh"
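
// The kernel below relies on blockReduceSum<T> from reduction.cuh, which is
// not shown in this file. The guarded block that follows is only a minimal
// sketch of the warp-shuffle block reduction it is assumed to perform; the
// *Sketch names and the APHRODITE_REDUCTION_SKETCH guard are hypothetical and
// keep the sketch from ever colliding with the real header.
#ifdef APHRODITE_REDUCTION_SKETCH
template<typename T>
__inline__ __device__ T warpReduceSumSketch(T val) {
  // Tree reduction within a warp via shuffle-down.
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffff, val, offset);
  return val;
}

template<typename T>
__inline__ __device__ T blockReduceSumSketch(T val) {
  static __shared__ T shared[32];  // one partial sum per warp
  const int lane = threadIdx.x % warpSize;
  const int wid = threadIdx.x / warpSize;

  val = warpReduceSumSketch<T>(val);
  if (lane == 0) shared[wid] = val;
  __syncthreads();

  // The first warp reduces the per-warp partials to the block total.
  const int num_warps = (blockDim.x + warpSize - 1) / warpSize;
  val = (threadIdx.x < num_warps) ? shared[lane] : (T)0.0f;
  if (wid == 0) val = warpReduceSumSketch<T>(val);
  return val;  // only thread 0 is guaranteed to hold the full block sum
}
#endif  // APHRODITE_REDUCTION_SKETCH
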
namespace aphrodite {

// TODO: Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,           // [..., hidden_size]
  const scalar_t* __restrict__ input,   // [..., hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  // One thread block handles one token; its threads cooperate over hidden_size.
  __shared__ float s_variance;
  float variance = 0.0f;

  // Accumulate the sum of squares for this token in float for numerical stability.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    // RMSNorm scale: 1 / sqrt(mean(x^2) + epsilon).
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Normalize and apply the learned per-channel weight.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace aphrodite
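
// APHRODITE_DISPATCH_FLOATING_TYPES comes from dispatch_utils.h (not shown
// here). A minimal sketch, assuming it wraps ATen's dispatch machinery to
// instantiate the lambda for float, double, half and bfloat16; the fallback
// below is illustration only and steps aside whenever the real macro is
// already defined by the included header.
#ifndef APHRODITE_DISPATCH_FLOATING_TYPES
#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)      \
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half,         \
                                  at::ScalarType::BFloat16,     \
                                  TYPE, NAME, __VA_ARGS__)
#endif
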
void rms_norm(
  torch::Tensor& out,     // [..., hidden_size]
  torch::Tensor& input,   // [..., hidden_size]
  torch::Tensor& weight,  // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  // One block per token, at most 1024 threads per block.
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
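
// A minimal standalone-binding and usage sketch. In the actual repository
// rms_norm is registered alongside the other custom ops elsewhere; the
// hypothetical APHRODITE_STANDALONE_RMSNORM guard below only illustrates how
// this file could be exposed and exercised as its own PyTorch extension.
#ifdef APHRODITE_STANDALONE_RMSNORM
#include <iostream>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rms_norm", &rms_norm,
        "out = input * rsqrt(mean(input^2, dim=-1) + epsilon) * weight");
}

// Host-side sanity check comparing the kernel against a plain tensor-op
// reference of the same RMSNorm formula.
void rms_norm_selftest() {
  const float epsilon = 1e-6f;
  torch::Tensor input = torch::randn(
      {4, 4096}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
  torch::Tensor weight = torch::randn(
      {4096}, torch::dtype(torch::kFloat32).device(torch::kCUDA));
  torch::Tensor out = torch::empty_like(input);

  rms_norm(out, input, weight, epsilon);

  torch::Tensor ref = input *
      torch::rsqrt(input.pow(2).mean(-1, /*keepdim=*/true) + epsilon) * weight;
  std::cout << "rms_norm max abs diff: "
            << (out - ref).abs().max().item<float>() << std::endl;
}
#endif  // APHRODITE_STANDALONE_RMSNORM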