activation.cpp

#include "cpu_types.hpp"

namespace {
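
// Generic vectorized activation kernel: processes num_tokens rows of width d,
// widening VEC_ELEM_NUM-element vectors to FP32, applying func, and storing
// the result back in scalar_t. For gated variants (is_gated == true) each
// input row holds 2 * d elements laid out as [x | y], and the output row is
// func(x) * y.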
template <typename scalar_t,
          vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8 &), bool is_gated>
void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
                       scalar_t *__restrict__ output) {
  using scalar_vec_t = vec_op::vec_t<scalar_t>;
  constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  TORCH_CHECK(d % VEC_ELEM_NUM == 0);

#pragma omp parallel for
  for (int i = 0; i < num_tokens; ++i) {
    for (int j = 0; j < d; j += VEC_ELEM_NUM) {
      int start = i * d;
      if constexpr (is_gated) {
        start *= 2;
      }

      const scalar_vec_t x(input + start + j);
      const vec_op::FP32Vec8 f32_x(x);
      vec_op::FP32Vec8 f32_ans = func(f32_x);

      if constexpr (is_gated) {
        const scalar_vec_t y(input + start + d + j);
        const vec_op::FP32Vec8 f32_y(y);
        f32_ans = f32_y * f32_ans;
      }

      const scalar_vec_t result(f32_ans);
      result.save(output + i * d + j);
    }
  }
}
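
// SiLU: x * sigmoid(x) = x / (1 + exp(-x)).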
FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 zeros(0.0);
  const vec_op::FP32Vec8 ones(1.0);
  return x / (ones + (zeros - x).exp());
}
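
// "new" GELU: the tanh approximation
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).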
FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(0.79788456f);
  const vec_op::FP32Vec8 w2(0.044715f);
  const vec_op::FP32Vec8 w3(0.5);
  const vec_op::FP32Vec8 x3 = x * x * x;
  const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
  return w3 * x * (ones + t);
}
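
// "fast" GELU: the same tanh approximation, with the tanh argument factored
// as sqrt(2/pi) * x * (1 + 0.044715 * x^2).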
FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(0.79788456f);
  const vec_op::FP32Vec8 w2(0.044715f);
  const vec_op::FP32Vec8 w3(0.5);
  const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
  return w3 * x * (ones + t);
}
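
// Reference GELU: 0.5 * x * (1 + erf(x / sqrt(2))).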
FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(M_SQRT1_2);
  const vec_op::FP32Vec8 w2(0.5);
  return x * w2 * (ones + (x * w1).er());
}
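
// GELU with the explicit tanh approximation; sqrt(2/pi) is written as
// M_SQRT2 * M_2_SQRTPI * 0.5.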
FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8 &x) {
  const vec_op::FP32Vec8 ones(1.0);
  const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
  const vec_op::FP32Vec8 w2(0.5);
  const vec_op::FP32Vec8 w3(0.044715);
  const vec_op::FP32Vec8 x_3 = x * x * x;
  const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
  return x * w2 * (ones + inner.tanh());
}
}  // namespace
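
// PyTorch-facing entry points. For the *_and_mul variants, `input` is
// [..., 2 * d] and `out` is [..., d]; for gelu_new and gelu_fast both are
// [..., d]. num_tokens collapses all leading dimensions.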
void silu_and_mul(torch::Tensor &out, torch::Tensor &input) {
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "silu_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
        activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
                                                    input.data_ptr<scalar_t>(),
                                                    out.data_ptr<scalar_t>());
        CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
      });
}

void gelu_and_mul(torch::Tensor &out,   // [..., d]
                  torch::Tensor &input) // [..., 2 * d]
{
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "gelu_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
        activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
                                                    input.data_ptr<scalar_t>(),
                                                    out.data_ptr<scalar_t>());
        CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
      });
}

void gelu_tanh_and_mul(torch::Tensor &out,   // [..., d]
                       torch::Tensor &input) // [..., 2 * d]
{
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1) / 2;

  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
        CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
        activation_kernel<scalar_t, gelu_tanh_act, true>(
            num_tokens, d, input.data_ptr<scalar_t>(),
            out.data_ptr<scalar_t>());
        CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
      });
}

void gelu_new(torch::Tensor &out, torch::Tensor &input) {
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1);

  APHRODITE_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_new_impl)
    activation_kernel<scalar_t, gelu_new_act, false>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_new_impl)
  });
}

void gelu_fast(torch::Tensor &out, torch::Tensor &input) {
  int num_tokens = input.numel() / input.size(-1);
  int d = input.size(-1);

  APHRODITE_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
    CPU_KERNEL_GUARD_IN(gelu_fast_impl)
    activation_kernel<scalar_t, gelu_fast_act, false>(
        num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
    CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
  });
}