layernorm.cpp

// CPU RMSNorm kernels: rms_norm writes the normalized result to a separate
// output tensor; fused_add_rms_norm fuses the residual add with the
// normalization and updates both tensors in place.
#include "cpu_types.hpp"

namespace {

template <typename scalar_t>
void rms_norm_impl(scalar_t* __restrict__ out,
                   const scalar_t* __restrict__ input,
                   const scalar_t* __restrict__ weight, const float epsilon,
                   const int num_tokens, const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    vec_op::FP32Vec8 variance(0.0);
    auto input_p = input + i * hidden_size;
    auto output_p = out + i * hidden_size;

    // First pass: accumulate the per-token sum of squares in FP32.
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      vec_op::FP32Vec8 fp32_x(x);
      variance = variance + fp32_x * fp32_x;
    }

    // 1 / rms, from the mean of squares plus epsilon.
    float s_variance =
        1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
    vec_op::FP32Vec8 fp32_s_variance(s_variance);

    // Second pass: scale each element by 1/rms and the learned weight.
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      scalar_vec_t w(weight + j);
      vec_op::FP32Vec8 fp32_x(x);
      vec_op::FP32Vec8 fp32_w(w);
      vec_op::FP32Vec8 fp32_out = fp32_x * fp32_s_variance * fp32_w;
      scalar_vec_t out(fp32_out);
      out.save(output_p + j);
    }
  }
}
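
// Editor's sketch (an addition, not part of the original file; never
// instantiated, so it compiles away): the scalar semantics that
// rms_norm_impl vectorizes. Per token,
//   out[j] = input[j] / sqrt(mean(input^2) + epsilon) * weight[j].
template <typename scalar_t>
void rms_norm_scalar_ref(scalar_t* out, const scalar_t* input,
                         const scalar_t* weight, const float epsilon,
                         const int num_tokens, const int hidden_size) {
  for (int i = 0; i < num_tokens; ++i) {
    const scalar_t* x = input + i * hidden_size;
    // Sum of squares in FP32, mirroring the kernel's accumulation precision.
    float sum_sq = 0.0f;
    for (int j = 0; j < hidden_size; ++j) {
      const float v = static_cast<float>(x[j]);
      sum_sq += v * v;
    }
    const float inv_rms = 1.0f / sqrtf(sum_sq / (float)hidden_size + epsilon);
    for (int j = 0; j < hidden_size; ++j) {
      out[i * hidden_size + j] = static_cast<scalar_t>(
          static_cast<float>(x[j]) * inv_rms * static_cast<float>(weight[j]));
    }
  }
}
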
template <typename scalar_t>
void fused_add_rms_norm_impl(scalar_t* __restrict__ input,
                             scalar_t* __restrict__ residual,
                             const scalar_t* __restrict__ weight,
                             const float epsilon, const int num_tokens,
                             const int hidden_size) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(hidden_size % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    vec_op::FP32Vec8 variance(0.0);
    auto input_p = input + i * hidden_size;
    auto residual_p = residual + i * hidden_size;

    // First pass: x = input + residual. Store the sum back into `residual`
    // (it becomes the residual stream for the next layer) while accumulating
    // the sum of squares.
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t x(input_p + j);
      scalar_vec_t res(residual_p + j);
      vec_op::FP32Vec8 fp32_x(x);
      vec_op::FP32Vec8 fp32_res(res);
      fp32_x = fp32_x + fp32_res;
      variance = variance + fp32_x * fp32_x;
      scalar_vec_t out(fp32_x);
      out.save(residual_p + j);
    }

    float s_variance =
        1.0f / sqrtf(variance.reduce_sum() / (float)hidden_size + epsilon);
    vec_op::FP32Vec8 fp32_s_variance(s_variance);

    // Second pass: re-read the summed values, normalize, and write the
    // result into `input`.
    for (int j = 0; j < hidden_size; j += VEC_ELEM_NUM) {
      scalar_vec_t w(weight + j);
      scalar_vec_t res(residual_p + j);
      vec_op::FP32Vec8 fp32_w(w);
      vec_op::FP32Vec8 fp32_res(res);
      vec_op::FP32Vec8 fp32_out = fp32_res * fp32_s_variance * fp32_w;
      scalar_vec_t out(fp32_out);
      out.save(input_p + j);
    }
  }
}

}  // namespace
void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
              double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  APHRODITE_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_impl", [&] {
    CPU_KERNEL_GUARD_IN(rms_norm_impl)
    rms_norm_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
                  weight.data_ptr<scalar_t>(), epsilon, num_tokens,
                  hidden_size);
    CPU_KERNEL_GUARD_OUT(rms_norm_impl)
  });
}
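
// Usage sketch (editor's illustration; assumes this file is built into a
// Torch extension that exposes rms_norm, and that num_tokens/hidden_size
// are placeholder dimensions):
//
//   auto input  = torch::randn({num_tokens, hidden_size});
//   auto weight = torch::ones({hidden_size});
//   auto out    = torch::empty_like(input);
//   rms_norm(out, input, weight, /*epsilon=*/1e-6);
//
// hidden_size must be a multiple of the SIMD vector width, per the
// TORCH_CHECK in rms_norm_impl.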
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                        torch::Tensor& weight, double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "fused_add_rms_norm_impl", [&] {
        CPU_KERNEL_GUARD_IN(fused_add_rms_norm_impl)
        fused_add_rms_norm_impl(
            input.data_ptr<scalar_t>(), residual.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
        CPU_KERNEL_GUARD_OUT(fused_add_rms_norm_impl)
      });
}
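
// Usage sketch for the fused variant (editor's illustration, same
// assumptions as above). Note the in-place contract: on return, `residual`
// holds input + residual and `input` holds the normalized output.
//
//   auto hidden   = torch::randn({num_tokens, hidden_size});
//   auto residual = torch::randn({num_tokens, hidden_size});
//   auto weight   = torch::ones({hidden_size});
//   fused_add_rms_norm(hidden, residual, weight, /*epsilon=*/1e-6);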