math.cuh 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. /*
  2. * Copyright (c) 2024 by PygmalionAI team.
  3. * Copyright (c) 2023 by FlashInfer team.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #ifndef APHRODITE_MATH_CUH_
  18. #define APHRODITE_MATH_CUH_
  19. #include <cuda_fp16.h>
  20. #include <cuda_runtime.h>
  21. namespace aphrodite {
  22. namespace math {
  /*!
   * \brief Base-2 logarithm of Euler's number, log2(e).
   *
   * Scaling an exponent by this constant converts a natural exponential into a
   * base-2 one: e^x == 2^(x * log2e).
   */
  constexpr float log2e = 1.44269504088896340736f;
  /*!
   * \brief Reinterpret the 32 bits of an unsigned integer as a half2 value.
   *
   * No numeric conversion takes place; the bit pattern is returned unchanged.
   * \param x 32-bit pattern to reinterpret
   * \return half2 whose storage holds exactly the bits of x
   */
  __forceinline__ __device__ half2 uint32_as_half2(uint32_t x) {
    half2* const as_half2 = reinterpret_cast<half2*>(&x);
    return *as_half2;
  }
  /*!
   * \brief Reinterpret the storage of a half2 value as a 32-bit unsigned integer.
   *
   * No numeric conversion takes place; the bit pattern is returned unchanged.
   * \param x half2 value to reinterpret
   * \return 32-bit integer holding exactly the bits of x
   */
  __forceinline__ __device__ uint32_t half2_as_uint32(half2 x) {
    uint32_t* const as_bits = reinterpret_cast<uint32_t*>(&x);
    return *as_bits;
  }
  /*!
   * \brief Approximate 2^x for a single-precision float.
   *
   * Issues the PTX instruction ex2.approx.ftz.f32 (approximate result, with the
   * flush-to-zero modifier).
   * \param x exponent
   * \return approximation of 2 raised to the power x
   */
  __forceinline__ __device__ float ptx_exp2(float x) {
    float result;
    asm volatile("ex2.approx.ftz.f32 %0, %1;"
                 : "=f"(result)
                 : "f"(x));
    return result;
  }
  /*!
   * \brief Approximate log2(x) for a single-precision float.
   *
   * Issues the PTX instruction lg2.approx.ftz.f32 (approximate result, with the
   * flush-to-zero modifier).
   * \param x input value
   * \return approximation of the base-2 logarithm of x
   */
  __forceinline__ __device__ float ptx_log2(float x) {
    float result;
    asm volatile("lg2.approx.ftz.f32 %0, %1;"
                 : "=f"(result)
                 : "f"(x));
    return result;
  }
  /*!
   * \brief Approximate 2^x for both halves of a half2 at once.
   *
   * Issues the PTX instruction ex2.approx.f16x2. The half2 operand travels
   * through a 32-bit integer register, so it is bit-cast on the way in and out.
   * NOTE(review): the f16x2 form of ex2.approx is believed to need SM75+ —
   * confirm against the PTX ISA before targeting older architectures.
   * \param x pair of half-precision exponents
   * \return pair of approximations of 2 raised to each element of x
   */
  __forceinline__ __device__ half2 ptx_exp2(half2 x) {
    const uint32_t packed_in = half2_as_uint32(x);
    uint32_t packed_out;
    asm volatile("ex2.approx.f16x2 %0, %1;"
                 : "=r"(packed_out)
                 : "r"(packed_in));
    return uint32_as_half2(packed_out);
  }
  /*!
   * \brief Approximate 2^x for a single half-precision value.
   *
   * Issues the PTX instruction ex2.approx.f16. The operand is passed through a
   * 16-bit integer register ("h" constraint), so the half is bit-cast with
   * __half_as_ushort / __ushort_as_half on the way in and out.
   * NOTE(review): the f16 form of ex2.approx is believed to need SM75+ —
   * confirm against the PTX ISA before targeting older architectures.
   * \param x half-precision exponent
   * \return approximation of 2 raised to the power x
   */
  __forceinline__ __device__ half ptx_exp2(half x) {
    // Spell out `unsigned short`: `ushort` is a non-standard typedef that is
    // not available on every host toolchain.
    const unsigned short x_bits = __half_as_ushort(x);
    unsigned short y_bits;
    asm volatile("ex2.approx.f16 %0, %1;"
                 : "=h"(y_bits)
                 : "h"(x_bits));
    return __ushort_as_half(y_bits);
  }
  /*!
   * \brief Approximate reciprocal 1/x for a single-precision float.
   *
   * Issues the PTX instruction rcp.approx.ftz.f32 (approximate result, with the
   * flush-to-zero modifier).
   * \param x input value
   * \return approximation of 1 / x
   */
  __forceinline__ __device__ float ptx_rcp(float x) {
    float reciprocal;
    asm volatile("rcp.approx.ftz.f32 %0, %1;"
                 : "=f"(reciprocal)
                 : "f"(x));
    return reciprocal;
  }
  /*!
   * \brief Butterfly shuffle of a float between the threads of a warp:
   * y[i] <- x[i ^ lane_mask].
   *
   * Uses the __shfl_xor_sync intrinsic with a full participation mask
   * (0xffffffff), matching the half2 overload, so all 32 lanes of the warp are
   * named as participants.
   * \param x The value contributed by the calling lane
   * \param lane_mask The mask that is xor-ed with the lane index to select the
   * source lane
   * \return The value of x held by lane (lane index ^ lane_mask)
   */
  __forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) {
    // Default width (warpSize) yields the same 0x1f clamp the raw
    // shfl.sync.bfly.b32 form would be given.
    return __shfl_xor_sync(0xffffffffu, x, lane_mask);
  }
  /*!
   * \brief Butterfly shuffle of a half2 between the threads of a warp:
   * y[i] <- x[i ^ lane_mask].
   *
   * Delegates to the __shfl_xor_sync intrinsic with a full participation mask
   * (0xffffffff).
   * \param x The value contributed by the calling lane
   * \param lane_mask The mask that is xor-ed with the lane index to select the
   * source lane
   * \return The value of x held by lane (lane index ^ lane_mask)
   */
  __forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) {
    constexpr unsigned int kAllLanes = 0xffffffff;
    const half2 exchanged = __shfl_xor_sync(kAllLanes, x, lane_mask);
    return exchanged;
  }
  /*!
   * \brief Approximate reciprocal square root 1/sqrt(x) for a float.
   *
   * Issues the PTX instruction rsqrt.approx.ftz.f32 (approximate result, with
   * the flush-to-zero modifier).
   * \param x input value
   * \return approximation of 1 / sqrt(x)
   */
  __forceinline__ __device__ float rsqrt(float x) {
    float inv_root;
    asm volatile("rsqrt.approx.ftz.f32 %0, %1;"
                 : "=f"(inv_root)
                 : "f"(x));
    return inv_root;
  }
  /*!
   * \brief Approximate hyperbolic tangent tanh(x) for a float.
   *
   * Issues the PTX instruction tanh.approx.f32.
   * NOTE(review): tanh.approx is believed to need SM75+ — confirm against the
   * PTX ISA before targeting older architectures.
   * \param x input value
   * \return approximation of tanh(x)
   */
  __forceinline__ __device__ float tanh(float x) {
    float result;
    asm volatile("tanh.approx.f32 %0, %1;"
                 : "=f"(result)
                 : "f"(x));
    return result;
  }
  /*!
   * \brief Approximate hyperbolic tangent for both halves of a half2 at once.
   *
   * Issues the PTX instruction tanh.approx.f16x2. The half2 operand travels
   * through a 32-bit integer register, so it is bit-cast on the way in and out.
   * NOTE(review): tanh.approx is believed to need SM75+ — confirm against the
   * PTX ISA before targeting older architectures.
   * \param x pair of half-precision inputs
   * \return pair of approximations of tanh applied to each element of x
   */
  __forceinline__ __device__ half2 tanh(half2 x) {
    const uint32_t packed_in = half2_as_uint32(x);
    uint32_t packed_out;
    asm volatile("tanh.approx.f16x2 %0, %1;"
                 : "=r"(packed_out)
                 : "r"(packed_in));
    return uint32_as_half2(packed_out);
  }
  /*!
   * \brief Approximate hyperbolic tangent tanh(x) for a single half value.
   *
   * Issues the PTX instruction tanh.approx.f16. The operand is passed through a
   * 16-bit integer register ("h" constraint), so the half is bit-cast with
   * __half_as_ushort / __ushort_as_half on the way in and out.
   * NOTE(review): tanh.approx is believed to need SM75+ — confirm against the
   * PTX ISA before targeting older architectures.
   * \param x half-precision input
   * \return approximation of tanh(x)
   */
  __forceinline__ __device__ half tanh(half x) {
    // Spell out `unsigned short`: `ushort` is a non-standard typedef that is
    // not available on every host toolchain.
    const unsigned short x_bits = __half_as_ushort(x);
    unsigned short y_bits;
    asm volatile("tanh.approx.f16 %0, %1;"
                 : "=h"(y_bits)
                 : "h"(x_bits));
    return __ushort_as_half(y_bits);
  }
  140. } // namespace math
  141. } // namespace aphrodite
  142. #endif // APHRODITE_MATH_CUH_