#include "cpu_types.hpp"
  2. namespace {
  3. template <typename scalar_t, vec_op::FP32Vec8 (*func)(const vec_op::FP32Vec8&),
  4. bool is_gated>
  5. void activation_kernel(int num_tokens, int d, scalar_t* __restrict__ input,
  6. scalar_t* __restrict__ output) {
  7. using scalar_vec_t = vec_op::vec_t<scalar_t>;
  8. constexpr int VEC_ELEM_NUM = scalar_vec_t::get_elem_num();
  9. TORCH_CHECK(d % VEC_ELEM_NUM == 0);
  10. #pragma omp parallel for
  11. for (int i = 0; i < num_tokens; ++i) {
  12. for (int j = 0; j < d; j += VEC_ELEM_NUM) {
  13. int start = i * d;
  14. if constexpr (is_gated) {
  15. start *= 2;
  16. }
  17. const scalar_vec_t x(input + start + j);
  18. const vec_op::FP32Vec8 f32_x(x);
  19. vec_op::FP32Vec8 f32_ans = func(f32_x);
  20. if constexpr (is_gated) {
  21. const scalar_vec_t y(input + start + d + j);
  22. const vec_op::FP32Vec8 f32_y(y);
  23. f32_ans = f32_y * f32_ans;
  24. }
  25. const scalar_vec_t result(f32_ans);
  26. result.save(output + i * d + j);
  27. }
  28. }
  29. }
  30. FORCE_INLINE vec_op::FP32Vec8 silu_act(const vec_op::FP32Vec8& x) {
  31. const vec_op::FP32Vec8 zeros(0.0);
  32. const vec_op::FP32Vec8 ones(1.0);
  33. return x / (ones + (zeros - x).exp());
  34. }
  35. FORCE_INLINE vec_op::FP32Vec8 gelu_new_act(const vec_op::FP32Vec8& x) {
  36. const vec_op::FP32Vec8 ones(1.0);
  37. const vec_op::FP32Vec8 w1(0.79788456f);
  38. const vec_op::FP32Vec8 w2(0.044715f);
  39. const vec_op::FP32Vec8 w3(0.5);
  40. const vec_op::FP32Vec8 x3 = x * x * x;
  41. const vec_op::FP32Vec8 t = (w1 * (x + w2 * x3)).tanh();
  42. return w3 * x * (ones + t);
  43. }
  44. FORCE_INLINE vec_op::FP32Vec8 gelu_fast_act(const vec_op::FP32Vec8& x) {
  45. const vec_op::FP32Vec8 ones(1.0);
  46. const vec_op::FP32Vec8 w1(0.79788456f);
  47. const vec_op::FP32Vec8 w2(0.044715f);
  48. const vec_op::FP32Vec8 w3(0.5);
  49. const vec_op::FP32Vec8 t = (x * w1 * (ones + x * w2 * x)).tanh();
  50. return w3 * x * (ones + t);
  51. }
  52. FORCE_INLINE vec_op::FP32Vec8 gelu_quick_act(const vec_op::FP32Vec8& x) {
  53. const vec_op::FP32Vec8 zeros(0.0);
  54. const vec_op::FP32Vec8 ones(1.0);
  55. const vec_op::FP32Vec8 w1(1.702f);
  56. return x / (ones + (zeros - w1 * x).exp());
  57. }
  58. FORCE_INLINE vec_op::FP32Vec8 gelu_act(const vec_op::FP32Vec8& x) {
  59. const vec_op::FP32Vec8 ones(1.0);
  60. const vec_op::FP32Vec8 w1(M_SQRT1_2);
  61. const vec_op::FP32Vec8 w2(0.5);
  62. return x * w2 * (ones + (x * w1).er());
  63. }
  64. FORCE_INLINE vec_op::FP32Vec8 gelu_tanh_act(const vec_op::FP32Vec8& x) {
  65. const vec_op::FP32Vec8 ones(1.0);
  66. const vec_op::FP32Vec8 w1(M_SQRT2 * M_2_SQRTPI * 0.5);
  67. const vec_op::FP32Vec8 w2(0.5);
  68. const vec_op::FP32Vec8 w3(0.044715);
  69. const vec_op::FP32Vec8 x_3 = x * x * x;
  70. const vec_op::FP32Vec8 inner = w1 * (x + x_3 * w3);
  71. return x * w2 * (ones + inner.tanh());
  72. }
  73. }; // namespace
  74. void silu_and_mul(torch::Tensor& out, torch::Tensor& input) {
  75. int num_tokens = input.numel() / input.size(-1);
  76. int d = input.size(-1) / 2;
  77. APHRODITE_DISPATCH_FLOATING_TYPES(
  78. input.scalar_type(), "silu_and_mul_impl", [&] {
  79. CPU_KERNEL_GUARD_IN(silu_and_mul_impl)
  80. activation_kernel<scalar_t, silu_act, true>(num_tokens, d,
  81. input.data_ptr<scalar_t>(),
  82. out.data_ptr<scalar_t>());
  83. CPU_KERNEL_GUARD_OUT(silu_and_mul_impl)
  84. });
  85. }
  86. void gelu_and_mul(torch::Tensor& out, // [..., d]
  87. torch::Tensor& input) // [..., 2 * d]
  88. {
  89. int num_tokens = input.numel() / input.size(-1);
  90. int d = input.size(-1) / 2;
  91. APHRODITE_DISPATCH_FLOATING_TYPES(
  92. input.scalar_type(), "gelu_and_mul_impl", [&] {
  93. CPU_KERNEL_GUARD_IN(gelu_and_mul_impl)
  94. activation_kernel<scalar_t, gelu_act, true>(num_tokens, d,
  95. input.data_ptr<scalar_t>(),
  96. out.data_ptr<scalar_t>());
  97. CPU_KERNEL_GUARD_OUT(gelu_and_mul_impl)
  98. });
  99. }
  100. void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
  101. torch::Tensor& input) // [..., 2 * d]
  102. {
  103. int num_tokens = input.numel() / input.size(-1);
  104. int d = input.size(-1) / 2;
  105. APHRODITE_DISPATCH_FLOATING_TYPES(
  106. input.scalar_type(), "gelu_tanh_and_mul_impl", [&] {
  107. CPU_KERNEL_GUARD_IN(gelu_tanh_and_mul_impl)
  108. activation_kernel<scalar_t, gelu_tanh_act, true>(
  109. num_tokens, d, input.data_ptr<scalar_t>(),
  110. out.data_ptr<scalar_t>());
  111. CPU_KERNEL_GUARD_OUT(gelu_tanh_and_mul_impl)
  112. });
  113. }
  114. void gelu_new(torch::Tensor& out, torch::Tensor& input) {
  115. int num_tokens = input.numel() / input.size(-1);
  116. int d = input.size(-1);
  117. APHRODITE_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_new_impl", [&] {
  118. CPU_KERNEL_GUARD_IN(gelu_new_impl)
  119. activation_kernel<scalar_t, gelu_new_act, false>(
  120. num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
  121. CPU_KERNEL_GUARD_OUT(gelu_new_impl)
  122. });
  123. }
  124. void gelu_fast(torch::Tensor& out, torch::Tensor& input) {
  125. int num_tokens = input.numel() / input.size(-1);
  126. int d = input.size(-1);
  127. APHRODITE_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_fast_impl", [&] {
  128. CPU_KERNEL_GUARD_IN(gelu_fast_impl)
  129. activation_kernel<scalar_t, gelu_fast_act, false>(
  130. num_tokens, d, input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>());
  131. CPU_KERNEL_GUARD_OUT(gelu_fast_impl)
  132. });
  133. }
  134. void gelu_quick(torch::Tensor& out, torch::Tensor& input) {
  135. int num_tokens = input.numel() / input.size(-1);
  136. int d = input.size(-1);
  137. APHRODITE_DISPATCH_FLOATING_TYPES(
  138. input.scalar_type(), "gelu_quick_impl", [&] {
  139. CPU_KERNEL_GUARD_IN(gelu_quick_impl)
  140. activation_kernel<scalar_t, gelu_quick_act, false>(
  141. num_tokens, d, input.data_ptr<scalar_t>(),
  142. out.data_ptr<scalar_t>());
  143. CPU_KERNEL_GUARD_OUT(gelu_quick_impl)
  144. });
  145. }