123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159 |
- /*
- * Copyright (c) 2024 by PygmalionAI team.
- * Copyright (c) 2023 by FlashInfer team.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef APHRODITE_MATH_CUH_
- #define APHRODITE_MATH_CUH_
- #include <cuda_fp16.h>
- #include <cuda_runtime.h>
- namespace aphrodite {
- namespace math {
// log2(e). Scaling an exponent by this constant converts a natural
// exponential into a base-2 one: e^x == 2^(x * log2e).
constexpr float log2e{1.44269504088896340736f};
/*!
 * \brief Reinterpret the bits of a 32-bit unsigned integer as a half2.
 *   This is a bit-for-bit reinterpretation, not a numeric conversion.
 * \param x 32-bit pattern to reinterpret
 * \return half2 read from the storage of x
 */
__forceinline__ __device__ half2 uint32_as_half2(uint32_t x) {
  half2* const as_half2 = reinterpret_cast<half2*>(&x);
  return *as_half2;
}
/*!
 * \brief Reinterpret the bits of a half2 as a 32-bit unsigned integer.
 *   This is a bit-for-bit reinterpretation, not a numeric conversion.
 * \param x half2 value whose storage is read
 * \return uint32_t read from the storage of x
 */
__forceinline__ __device__ uint32_t half2_as_uint32(half2 x) {
  uint32_t* const as_u32 = reinterpret_cast<uint32_t*>(&x);
  return *as_u32;
}
/*!
 * \brief Approximate base-2 exponential 2^x, issued as a single PTX
 *   ex2.approx.ftz.f32 instruction (the flush-to-zero variant).
 * \param x exponent
 * \return the instruction's approximation of 2^x
 */
__forceinline__ __device__ float ptx_exp2(float x) {
  float result;
  asm volatile("ex2.approx.ftz.f32 %0, %1;"
               : "=f"(result)
               : "f"(x));
  return result;
}
/*!
 * \brief Approximate base-2 logarithm log2(x), issued as a single PTX
 *   lg2.approx.ftz.f32 instruction (the flush-to-zero variant).
 * \param x input value
 * \return the instruction's approximation of log2(x)
 */
__forceinline__ __device__ float ptx_log2(float x) {
  float result;
  asm volatile("lg2.approx.ftz.f32 %0, %1;"
               : "=f"(result)
               : "f"(x));
  return result;
}
/*!
 * \brief Approximate base-2 exponential applied to both halves of a half2,
 *   issued as a single PTX ex2.approx.f16x2 instruction.
 * \param x packed pair of half-precision exponents
 * \return packed pair holding the instruction's approximation of 2^x
 * \note NOTE(review): the PTX ISA lists ex2.approx.f16x2 as needing sm_75 or
 *   newer - confirm against the architectures this file is built for.
 */
__forceinline__ __device__ half2 ptx_exp2(half2 x) {
  // The instruction works on a 32-bit register, so the packed halves are
  // passed in and read back as raw bits.
  uint32_t out_bits;
  asm volatile("ex2.approx.f16x2 %0, %1;"
               : "=r"(out_bits)
               : "r"(half2_as_uint32(x)));
  return uint32_as_half2(out_bits);
}
/*!
 * \brief Approximate base-2 exponential 2^x for a single half, issued as a
 *   PTX ex2.approx.f16 instruction.
 * \param x half-precision exponent
 * \return the instruction's approximation of 2^x
 * \note NOTE(review): the PTX ISA lists ex2.approx.f16 as needing sm_75 or
 *   newer - confirm against the architectures this file is built for.
 */
__forceinline__ __device__ half ptx_exp2(half x) {
  // Spelled `unsigned short` rather than `ushort`: the latter is a
  // non-standard typedef that only exists where a system header happens to
  // provide it. The "h" constraint binds a 16-bit register, and the half is
  // moved through it as raw bits.
  unsigned short y_bits;
  asm volatile("ex2.approx.f16 %0, %1;"
               : "=h"(y_bits)
               : "h"(__half_as_ushort(x)));
  return __ushort_as_half(y_bits);
}
/*!
 * \brief Approximate reciprocal 1/x, issued as a single PTX
 *   rcp.approx.ftz.f32 instruction (the flush-to-zero variant).
 * \param x input value
 * \return the instruction's approximation of 1/x
 */
__forceinline__ __device__ float ptx_rcp(float x) {
  float result;
  asm volatile("rcp.approx.ftz.f32 %0, %1;"
               : "=f"(result)
               : "f"(x));
  return result;
}
/*!
 * \brief Butterfly shuffle of a float between the threads of a warp, issued
 *   as a PTX shfl.sync.bfly.b32 instruction: y[i] <- x[i ^ lane_mask].
 * \param x The value contributed by the calling lane
 * \param lane_mask The mask XOR-ed with the lane index to select the source
 *   lane
 * \return The value of x held by the selected source lane
 * \note The instruction is issued with a hard-coded member mask of 0xffffffff
 *   (the full warp).
 */
__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) {
  float received;
  asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;"
               : "=f"(received)
               : "f"(x), "r"(lane_mask));
  return received;
}
/*!
 * \brief Butterfly shuffle of a half2 between the threads of a warp, done
 *   through the __shfl_xor_sync intrinsic: y[i] <- x[i ^ lane_mask].
 * \param x The value contributed by the calling lane
 * \param lane_mask The mask XOR-ed with the lane index to select the source
 *   lane
 * \return The value of x held by the selected source lane
 */
__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) {
  // Participation mask naming all 32 lanes of the warp.
  constexpr unsigned int kFullWarpMask = 0xffffffff;
  const half2 received = __shfl_xor_sync(kFullWarpMask, x, lane_mask);
  return received;
}
/*!
 * \brief Approximate reciprocal square root 1/sqrt(x), issued as a single PTX
 *   rsqrt.approx.ftz.f32 instruction (the flush-to-zero variant).
 * \param x input value
 * \return the instruction's approximation of 1/sqrt(x)
 */
__forceinline__ __device__ float rsqrt(float x) {
  float result;
  asm volatile("rsqrt.approx.ftz.f32 %0, %1;"
               : "=f"(result)
               : "f"(x));
  return result;
}
/*!
 * \brief Approximate hyperbolic tangent tanh(x), issued as a single PTX
 *   tanh.approx.f32 instruction.
 * \param x input value
 * \return the instruction's approximation of tanh(x)
 * \note NOTE(review): the PTX ISA lists tanh.approx.f32 as needing sm_75 or
 *   newer - confirm against the architectures this file is built for.
 */
__forceinline__ __device__ float tanh(float x) {
  float result;
  asm volatile("tanh.approx.f32 %0, %1;"
               : "=f"(result)
               : "f"(x));
  return result;
}
/*!
 * \brief Approximate hyperbolic tangent applied to both halves of a half2,
 *   issued as a single PTX tanh.approx.f16x2 instruction.
 * \param x packed pair of half-precision inputs
 * \return packed pair holding the instruction's approximation of tanh(x)
 * \note NOTE(review): the PTX ISA lists tanh.approx.f16x2 as needing sm_75 or
 *   newer - confirm against the architectures this file is built for.
 */
__forceinline__ __device__ half2 tanh(half2 x) {
  // The instruction works on a 32-bit register, so the packed halves are
  // passed in and read back as raw bits.
  uint32_t out_bits;
  asm volatile("tanh.approx.f16x2 %0, %1;"
               : "=r"(out_bits)
               : "r"(half2_as_uint32(x)));
  return uint32_as_half2(out_bits);
}
/*!
 * \brief Approximate hyperbolic tangent tanh(x) for a single half, issued as
 *   a PTX tanh.approx.f16 instruction.
 * \param x half-precision input
 * \return the instruction's approximation of tanh(x)
 * \note NOTE(review): the PTX ISA lists tanh.approx.f16 as needing sm_75 or
 *   newer - confirm against the architectures this file is built for.
 */
__forceinline__ __device__ half tanh(half x) {
  // Spelled `unsigned short` rather than `ushort`: the latter is a
  // non-standard typedef that only exists where a system header happens to
  // provide it. The "h" constraint binds a 16-bit register, and the half is
  // moved through it as raw bits.
  unsigned short y_bits;
  asm volatile("tanh.approx.f16 %0, %1;"
               : "=h"(y_bits)
               : "h"(__half_as_ushort(x)));
  return __ushort_as_half(y_bits);
}
- } // namespace math
- } // namespace aphrodite
- #endif // APHRODITE_MATH_CUH_
|