/*
 * Copyright (c) 2024 by PygmalionAI team.
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef APHRODITE_MATH_CUH_
#define APHRODITE_MATH_CUH_

// NOTE(review): the header names of the two original #include directives were
// not legible when this file was reviewed. The headers below are the ones the
// declarations in this file need (half/half2 and their intrinsics, the CUDA
// function qualifiers, and uint32_t) -- confirm against the upstream file.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <cstdint>

namespace aphrodite {
namespace math {

// log2(e)
constexpr float log2e = 1.44269504088896340736f;

/*!
 * \brief Reinterpret the 32 bits of an unsigned integer as a half2.
 * \param x The raw bit pattern
 * \return A half2 with exactly the same bits as x (no numeric conversion)
 */
__forceinline__ __device__ half2 uint32_as_half2(uint32_t x) {
  return *reinterpret_cast<half2*>(&x);
}

/*!
 * \brief Reinterpret the 32 bits of a half2 as an unsigned integer.
 * \param x The half2 value
 * \return A uint32_t with exactly the same bits as x (no numeric conversion)
 */
__forceinline__ __device__ uint32_t half2_as_uint32(half2 x) {
  return *reinterpret_cast<uint32_t*>(&x);
}

/*!
 * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x
 * \param x input
 * \note Emits the .ftz form of the instruction.
 */
__forceinline__ __device__ float ptx_exp2(float x) {
  float y;
  asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
}

/*!
 * \brief Wrapper of PTX lg2.approx instruction, which computes log2(x)
 * \param x input
 * \note Emits the .ftz form of the instruction.
 */
__forceinline__ __device__ float ptx_log2(float x) {
  float y;
  asm volatile("lg2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
}

/*!
 * \brief Wrapper of PTX ex2.approx.f16x2 instruction, which computes 2^x
 *   on both halves of a half2
 * \param x input
 * \note NOTE(review): the PTX ISA lists the f16/f16x2 forms of ex2.approx as
 *   requiring sm_75 or newer -- confirm the build targets.
 */
__forceinline__ __device__ half2 ptx_exp2(half2 x) {
  // The packed pair travels through a 32-bit register ("r" constraint).
  uint32_t y_u32;
  uint32_t x_u32 = half2_as_uint32(x);
  asm volatile("ex2.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32));
  return uint32_as_half2(y_u32);
}

/*!
 * \brief Wrapper of PTX ex2.approx.f16 instruction, which computes 2^x
 * \param x input
 * \note NOTE(review): the PTX ISA lists the f16/f16x2 forms of ex2.approx as
 *   requiring sm_75 or newer -- confirm the build targets.
 */
__forceinline__ __device__ half ptx_exp2(half x) {
  // The value travels through a 16-bit register ("h" constraint). Plain
  // `unsigned short` is used rather than the non-standard `ushort` typedef.
  unsigned short y_u16;
  asm volatile("ex2.approx.f16 %0, %1;"
               : "=h"(y_u16)
               : "h"(__half_as_ushort(x)));
  return __ushort_as_half(y_u16);
}

/*!
 * \brief Wrapper of PTX rcp.approx instruction, which computes 1/x
 * \param x input
 * \note Emits the .ftz form of the instruction.
 */
__forceinline__ __device__ float ptx_rcp(float x) {
  float y;
  asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
}

/*!
 * \brief Wrapper of PTX shfl.sync.bfly instruction, which performs a butterfly
 *   shuffle between threads in a warp.
 * \param x The value in the source lane
 * \param lane_mask The mask to perform thread index xor with:
 *   y[i] <- x[i ^ lane_mask]
 * \note The member mask is hard-coded to 0xffffffff, i.e. the instruction
 *   names every lane of the warp as a participant.
 */
__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) {
  float y;
  asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;"
               : "=f"(y)
               : "f"(x), "r"(lane_mask));
  return y;
}

/*!
 * \brief Butterfly shuffle of a half2 between threads in a warp, implemented
 *   with the __shfl_xor_sync intrinsic.
 * \param x The value in the source lane
 * \param lane_mask The mask to perform thread index xor with:
 *   y[i] <- x[i ^ lane_mask]
 * \note The member mask is hard-coded to 0xffffffff, i.e. the call names
 *   every lane of the warp as a participant.
 */
__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) {
  return __shfl_xor_sync(0xffffffff, x, lane_mask);
}

/*!
 * \brief Wrapper of PTX rsqrt approximation instruction, which computes
 *   1/sqrt(x)
 * \param x input
 * \note Emits the .ftz form of the instruction.
 */
__forceinline__ __device__ float rsqrt(float x) {
  float y;
  asm volatile("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
}

/*!
 * \brief Wrapper of PTX tanh.approx.f32 instruction, which computes tanh(x)
 * \param x input
 * \note NOTE(review): the PTX ISA lists tanh.approx as requiring sm_75 or
 *   newer -- confirm the build targets.
 */
__forceinline__ __device__ float tanh(float x) {
  float y;
  asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
  return y;
}

/*!
 * \brief Wrapper of PTX tanh.approx.f16x2 instruction, which computes tanh(x)
 *   on both halves of a half2
 * \param x input
 * \note NOTE(review): the PTX ISA lists tanh.approx as requiring sm_75 or
 *   newer -- confirm the build targets.
 */
__forceinline__ __device__ half2 tanh(half2 x) {
  // The packed pair travels through a 32-bit register ("r" constraint).
  uint32_t y_u32;
  uint32_t x_u32 = half2_as_uint32(x);
  asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32));
  return uint32_as_half2(y_u32);
}

/*!
 * \brief Wrapper of PTX tanh.approx.f16 instruction, which computes tanh(x)
 * \param x input
 * \note NOTE(review): the PTX ISA lists tanh.approx as requiring sm_75 or
 *   newer -- confirm the build targets.
 */
__forceinline__ __device__ half tanh(half x) {
  // The value travels through a 16-bit register ("h" constraint). Plain
  // `unsigned short` is used rather than the non-standard `ushort` typedef.
  unsigned short y_u16;
  asm volatile("tanh.approx.f16 %0, %1;"
               : "=h"(y_u16)
               : "h"(__half_as_ushort(x)));
  return __ushort_as_half(y_u16);
}

}  // namespace math
}  // namespace aphrodite

#endif  // APHRODITE_MATH_CUH_