/*
 * Copyright (c) 2024 by PygmalionAI team.
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

#include <type_traits>

namespace aphrodite {

#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
  #define APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
#endif

#define APHRODITE_INLINE inline __attribute__((always_inline)) __device__

/******************* vec_t type cast *******************/

// Generic element-wise cast between two element types.
template <typename dst_t, typename src_t>
struct vec_cast {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(dst_t* dst, const src_t* src) {
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      dst[i] = (dst_t)src[i];
    }
  }
};

template <>
struct vec_cast<float, half> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(float* dst, const half* src) {
    if constexpr (vec_size == 1) {
      dst[0] = (float)src[0];
    } else {
#pragma unroll
      for (size_t i = 0; i < vec_size / 2; ++i) {
        ((float2*)dst)[i] = __half22float2(((half2*)src)[i]);
      }
    }
  }
};

template <>
struct vec_cast<half, float> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(half* dst, const float* src) {
    if constexpr (vec_size == 1) {
      dst[0] = __float2half(src[0]);
    } else {
#pragma unroll
      for (size_t i = 0; i < vec_size / 2; ++i) {
        ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]);
      }
    }
  }
};

template <typename T>
constexpr APHRODITE_INLINE int get_exponent_bits() {
  if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
    return 4;
  } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
    return 5;
  } else if constexpr (std::is_same<T, half>::value) {
    return 5;
  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
    return 8;
  }
}

template <typename T>
constexpr APHRODITE_INLINE int get_mantissa_bits() {
  if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
    return 3;
  } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
    return 2;
  } else if constexpr (std::is_same<T, half>::value) {
    return 11;
  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
    return 7;
  }
}
/*!
 * \brief Fallback to software fast dequant implementation if hardware
 *   dequantization is not available.
 * \note Inspired by Marlin's fast dequantization, but here we don't have to
 *   permute weights order.
 * \ref
 * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120
 */
template <typename fp8_dtype, typename fp16_dtype>
__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
  uint32_t q = *input;
  if constexpr (std::is_same<fp8_dtype, __nv_fp8_e5m2>::value &&
                std::is_same<fp16_dtype, half>::value) {
    // An e5m2 value is exactly the high byte of the corresponding fp16 value,
    // so dequantization is a pure byte shuffle: place each fp8 byte into the
    // high byte of a 16-bit lane.
    output->x = __byte_perm(0U, q, 0x5140);
    output->y = __byte_perm(0U, q, 0x7362);
  } else {
    constexpr int FP8_EXPONENT = get_exponent_bits<fp8_dtype>();
    constexpr int FP8_MANTISSA = get_mantissa_bits<fp8_dtype>();
    constexpr int FP16_EXPONENT = get_exponent_bits<fp16_dtype>();
    constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;

    // Calculate MASK for extracting mantissa and exponent
    constexpr int MASK1 = 0x80000000;
    constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
    constexpr int MASK3 = MASK2 & 0x7fffffff;
    constexpr int MASK = MASK3 | (MASK3 >> 16);
    q = __byte_perm(q, q, 0x1302);

    // Extract and shift FP8 values to FP16 format
    uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
    uint32_t Out2 =
        ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);

    constexpr int BIAS_OFFSET =
        (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
    // Construct and apply exponent bias
    if constexpr (std::is_same<fp16_dtype, half>::value) {
      const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));

      // Convert to half2 and apply bias
      *(half2*)&(output->x) =
          __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
      *(half2*)&(output->y) =
          __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
    } else {
      constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
      const nv_bfloat162 bias_reg =
          __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
      // Convert to bfloat162 and apply bias
      *(nv_bfloat162*)&(output->x) =
          __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
      *(nv_bfloat162*)&(output->y) =
          __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
    }
  }
}
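// Worked example of the constants above, for the e4m3 -> half path (a sketch;
// the numbers follow directly from the bit layouts used in this file):
//   FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5
//   RIGHT_SHIFT  = 5 - 4 = 1
//   MASK2 = 0x80000000 >> 7 = 0xFF000000  (arithmetic shift on a signed int)
//   MASK3 = 0xFF000000 & 0x7fffffff = 0x7F000000
//   MASK  = 0x7F000000 | 0x00007F00 = 0x7F007F00
// Each 16-bit lane keeps its sign bit, and the exponent+mantissa field is
// shifted right by one into fp16 position. The exponent is then still biased
// by 7 (the e4m3 bias) rather than 15 (the fp16 bias), hence the multiply by
//   2^BIAS_OFFSET = 2^((1 << 4) - (1 << 3)) = 2^8 = 256 = 2^(15 - 7),
// which re-biases the exponent without touching the sign or mantissa.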
template <>
struct vec_cast<nv_bfloat16, __nv_fp8_e4m3> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(nv_bfloat16* dst,
                                    const __nv_fp8_e4m3* src) {
    if constexpr (vec_size == 1) {
      dst[0] = nv_bfloat16(src[0]);
    } else if constexpr (vec_size == 2) {
      dst[0] = nv_bfloat16(src[0]);
      dst[1] = nv_bfloat16(src[1]);
    } else {
      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
#pragma unroll
      for (uint32_t i = 0; i < vec_size / 4; ++i) {
        fast_dequant_f8f16x4<__nv_fp8_e4m3, nv_bfloat16>(
            (uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]);
      }
    }
  }
};

template <>
struct vec_cast<nv_bfloat16, __nv_fp8_e5m2> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(nv_bfloat16* dst,
                                    const __nv_fp8_e5m2* src) {
    if constexpr (vec_size == 1) {
      dst[0] = nv_bfloat16(src[0]);
    } else if constexpr (vec_size == 2) {
      dst[0] = nv_bfloat16(src[0]);
      dst[1] = nv_bfloat16(src[1]);
    } else {
      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
#pragma unroll
      for (uint32_t i = 0; i < vec_size / 4; ++i) {
        fast_dequant_f8f16x4<__nv_fp8_e5m2, nv_bfloat16>(
            (uint32_t*)&src[i * 4], (uint2*)&dst[i * 4]);
      }
    }
  }
};

template <>
struct vec_cast<__nv_fp8_e4m3, half> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(__nv_fp8_e4m3* dst, const half* src) {
#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
    if constexpr (vec_size == 1) {
      dst[0] = __nv_fp8_e4m3(src[0]);
    } else {
#pragma unroll
      for (size_t i = 0; i < vec_size / 2; ++i) {
        uint16_t y;
        uint32_t x = *(uint32_t*)&src[i * 2];
        // Converts two packed fp16 lanes to two fp8 bytes, round-to-nearest
        // with saturation on overflow.
        asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;"
                     : "=h"(y)
                     : "r"(x));
        *(uint16_t*)&dst[i * 2] = y;
      }
    }
#else
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      dst[i] = __nv_fp8_e4m3(src[i]);
    }
#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  }
};

template <>
struct vec_cast<__nv_fp8_e5m2, half> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(__nv_fp8_e5m2* dst, const half* src) {
#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
    if constexpr (vec_size == 1) {
      dst[0] = __nv_fp8_e5m2(src[0]);
    } else {
#pragma unroll
      for (size_t i = 0; i < vec_size / 2; ++i) {
        uint16_t y;
        uint32_t x = *(uint32_t*)&src[i * 2];
        asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;"
                     : "=h"(y)
                     : "r"(x));
        *(uint16_t*)&dst[i * 2] = y;
      }
    }
#else
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      dst[i] = __nv_fp8_e5m2(src[i]);
    }
#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  }
};

template <>
struct vec_cast<half, __nv_fp8_e4m3> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e4m3* src) {
#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
    if constexpr (vec_size == 1) {
      dst[0] = half(src[0]);
    } else {
#pragma unroll
      for (size_t i = 0; i < vec_size / 2; ++i) {
        uint32_t y;
        uint16_t x = *(uint16_t*)&src[i * 2];
        asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x));
        *(uint32_t*)&dst[i * 2] = y;
      }
    }
#else
    if constexpr (vec_size == 1) {
      dst[0] = half(src[0]);
    } else if constexpr (vec_size == 2) {
      dst[0] = half(src[0]);
      dst[1] = half(src[1]);
    } else {
      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
#pragma unroll
      for (uint32_t i = 0; i < vec_size / 4; ++i) {
        fast_dequant_f8f16x4<__nv_fp8_e4m3, half>((uint32_t*)&src[i * 4],
                                                  (uint2*)&dst[i * 4]);
      }
    }
#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  }
};

template <>
struct vec_cast<half, __nv_fp8_e5m2> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e5m2* src) {
#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
    if constexpr (vec_size == 1) {
      dst[0] = half(src[0]);
    } else {
#pragma unroll
      for (size_t i = 0; i < vec_size / 2; ++i) {
        uint32_t y;
        uint16_t x = *(uint16_t*)&src[i * 2];
        asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x));
        *(uint32_t*)&dst[i * 2] = y;
      }
    }
#else
    if constexpr (vec_size == 1) {
      dst[0] = half(src[0]);
    } else if constexpr (vec_size == 2) {
      dst[0] = half(src[0]);
      dst[1] = half(src[1]);
    } else {
      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
#pragma unroll
      for (uint32_t i = 0; i < vec_size / 4; ++i) {
        fast_dequant_f8f16x4<__nv_fp8_e5m2, half>((uint32_t*)&src[i * 4],
                                                  (uint2*)&dst[i * 4]);
      }
    }
#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
  }
};

template <>
struct vec_cast<float, nv_bfloat16> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(float* dst, const nv_bfloat16* src) {
    if constexpr (vec_size == 1) {
      dst[0] = (float)src[0];
    } else {
#pragma unroll
      for (size_t i = 0; i < vec_size / 2; ++i) {
        ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]);
      }
    }
  }
};

template <>
struct vec_cast<nv_bfloat16, float> {
  template <size_t vec_size>
  APHRODITE_INLINE static void cast(nv_bfloat16* dst, const float* src) {
    if constexpr (vec_size == 1) {
      dst[0] = nv_bfloat16(src[0]);
    } else {
#pragma unroll
      for (size_t i = 0; i < vec_size / 2; ++i) {
        ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]);
      }
    }
  }
};

// Primary template: a register-resident vector of vec_size elements of type
// float_t, with vectorized load/store and type-converting load/store.
template <typename float_t, size_t vec_size>
struct vec_t {
  APHRODITE_INLINE float_t& operator[](size_t i);
  APHRODITE_INLINE const float_t& operator[](size_t i) const;
  APHRODITE_INLINE void fill(float_t val);
  APHRODITE_INLINE void load(const float_t* ptr);
  APHRODITE_INLINE void store(float_t* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src);
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr);
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const;
  APHRODITE_INLINE static void memcpy(float_t* dst, const float_t* src);
  APHRODITE_INLINE float_t* ptr();
};
template <typename tgt_float_t, typename src_float_t, size_t vec_size>
APHRODITE_INLINE void cast_from_impl(vec_t<tgt_float_t, vec_size>& dst,
                                     const vec_t<src_float_t, vec_size>& src) {
  vec_cast<tgt_float_t, src_float_t>::template cast<vec_size>(
      dst.ptr(), const_cast<vec_t<src_float_t, vec_size>*>(&src)->ptr());
}

template <typename tgt_float_t, typename src_float_t, size_t vec_size>
APHRODITE_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
                                     const src_float_t* src_ptr) {
  if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
    dst.load(src_ptr);
  } else {
    vec_t<src_float_t, vec_size> tmp;
    tmp.load(src_ptr);
    dst.cast_from(tmp);
  }
}

template <typename tgt_float_t, typename src_float_t, size_t vec_size>
APHRODITE_INLINE void cast_store_impl(tgt_float_t* dst_ptr,
                                      const vec_t<src_float_t, vec_size>& src) {
  if constexpr (std::is_same<tgt_float_t, src_float_t>::value) {
    src.store(dst_ptr);
  } else {
    vec_t<tgt_float_t, vec_size> tmp;
    tmp.cast_from(src);
    tmp.store(dst_ptr);
  }
}
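// Usage sketch (illustrative only, not part of this header): a hypothetical
// kernel that loads 8 halves per thread and stores them as e4m3, using only
// the vec_t API declared above. load() issues a single 128-bit load and
// cast_store() converts through vec_cast<__nv_fp8_e4m3, half> before one
// 64-bit store.
//
//   __global__ void cast_f16_to_fp8(__nv_fp8_e4m3* out, const half* in) {
//     const size_t offset = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
//     aphrodite::vec_t<half, 8> v;
//     v.load(in + offset);
//     v.cast_store(out + offset);
//   }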
/******************* vec_t<__nv_fp8_e4m3> *******************/

// __nv_fp8_e4m3 x 1
template <>
struct vec_t<__nv_fp8_e4m3, 1> {
  __nv_fp8_e4m3 data;

  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
    return ((__nv_fp8_e4m3*)(&data))[i];
  }
  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
    return ((const __nv_fp8_e4m3*)(&data))[i];
  }
  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
                                      const __nv_fp8_e4m3* src);
};

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) {
  data = val;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) {
  data = *ptr;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(
    __nv_fp8_e4m3* ptr) const {
  *ptr = data;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(
    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
  *dst = *src;
}

// __nv_fp8_e4m3 x 2
template <>
struct vec_t<__nv_fp8_e4m3, 2> {
  __nv_fp8x2_e4m3 data;

  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
    return ((__nv_fp8_e4m3*)(&data))[i];
  }
  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
    return ((const __nv_fp8_e4m3*)(&data))[i];
  }
  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
                                      const __nv_fp8_e4m3* src);
};

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
  data.__x =
      (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) {
  data = *((__nv_fp8x2_e4m3*)ptr);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(
    __nv_fp8_e4m3* ptr) const {
  *((__nv_fp8x2_e4m3*)ptr) = data;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(
    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
  *((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src);
}

// __nv_fp8_e4m3 x 4
template <>
struct vec_t<__nv_fp8_e4m3, 4> {
  __nv_fp8x4_e4m3 data;

  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
    return ((__nv_fp8_e4m3*)(&data))[i];
  }
  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
    return ((const __nv_fp8_e4m3*)(&data))[i];
  }
  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
                                      const __nv_fp8_e4m3* src);
};

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
  data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
             (__nv_fp8x4_storage_t(val.__x) << 16) |
             (__nv_fp8x4_storage_t(val.__x) << 8) |
             __nv_fp8x4_storage_t(val.__x);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) {
  data = *((__nv_fp8x4_e4m3*)ptr);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(
    __nv_fp8_e4m3* ptr) const {
  *((__nv_fp8x4_e4m3*)ptr) = data;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(
    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
  *((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src);
}

// __nv_fp8_e4m3 x 8
template <>
struct vec_t<__nv_fp8_e4m3, 8> {
  uint2 data;

  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
    return ((__nv_fp8_e4m3*)(&data))[i];
  }
  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
    return ((const __nv_fp8_e4m3*)(&data))[i];
  }
  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 8>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
                                      const __nv_fp8_e4m3* src);
};

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
  ((__nv_fp8x4_e4m3*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
                                       __nv_fp8x4_storage_t(val.__x);
  ((__nv_fp8x4_e4m3*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
                                       __nv_fp8x4_storage_t(val.__x);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) {
  data = *((uint2*)ptr);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(
    __nv_fp8_e4m3* ptr) const {
  *((uint2*)ptr) = data;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(
    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
  *((uint2*)dst) = *((uint2*)src);
}
// __nv_fp8_e4m3 x 16 or more
template <size_t vec_size>
struct vec_t<__nv_fp8_e4m3, vec_size> {
  uint4 data[vec_size / 16];

  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
    return ((__nv_fp8_e4m3*)data)[i];
  }
  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
    return ((const __nv_fp8_e4m3*)data)[i];
  }
  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 16; ++i) {
      ((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x =
          (__nv_fp8x4_storage_t(val.__x) << 24) |
          (__nv_fp8x4_storage_t(val.__x) << 16) |
          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
      ((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x =
          (__nv_fp8x4_storage_t(val.__x) << 24) |
          (__nv_fp8x4_storage_t(val.__x) << 16) |
          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
      ((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x =
          (__nv_fp8x4_storage_t(val.__x) << 24) |
          (__nv_fp8x4_storage_t(val.__x) << 16) |
          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
      ((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x =
          (__nv_fp8x4_storage_t(val.__x) << 24) |
          (__nv_fp8x4_storage_t(val.__x) << 16) |
          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
    }
  }
  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 16; ++i) {
      data[i] = ((uint4*)ptr)[i];
    }
  }
  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 16; ++i) {
      ((uint4*)ptr)[i] = data[i];
    }
  }
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
                                      const __nv_fp8_e4m3* src) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 16; ++i) {
      ((uint4*)dst)[i] = ((uint4*)src)[i];
    }
  }
};
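// Note on the fill() pattern above: packing one fp8 byte b into a 32-bit word
// as (b << 24) | (b << 16) | (b << 8) | b replicates it across all four byte
// lanes; e.g. b = 0x3C yields 0x3C3C3C3C (equivalently 0x01010101u * b). The
// __nv_fp8_e5m2 specializations below use the same replication.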
/******************* vec_t<__nv_fp8_e5m2> *******************/

// __nv_fp8_e5m2 x 1
template <>
struct vec_t<__nv_fp8_e5m2, 1> {
  __nv_fp8_e5m2 data;

  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
    return ((__nv_fp8_e5m2*)(&data))[i];
  }
  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
    return ((const __nv_fp8_e5m2*)(&data))[i];
  }
  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
                                      const __nv_fp8_e5m2* src);
};

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) {
  data = val;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) {
  data = *ptr;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(
    __nv_fp8_e5m2* ptr) const {
  *ptr = data;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(
    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
  *dst = *src;
}

// __nv_fp8_e5m2 x 2
template <>
struct vec_t<__nv_fp8_e5m2, 2> {
  __nv_fp8x2_e5m2 data;

  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
    return ((__nv_fp8_e5m2*)(&data))[i];
  }
  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
    return ((const __nv_fp8_e5m2*)(&data))[i];
  }
  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
                                      const __nv_fp8_e5m2* src);
};

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
  data.__x =
      (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) {
  data = *((__nv_fp8x2_e5m2*)ptr);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(
    __nv_fp8_e5m2* ptr) const {
  *((__nv_fp8x2_e5m2*)ptr) = data;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(
    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
  *((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src);
}

// __nv_fp8_e5m2 x 4
template <>
struct vec_t<__nv_fp8_e5m2, 4> {
  __nv_fp8x4_e5m2 data;

  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
    return ((__nv_fp8_e5m2*)(&data))[i];
  }
  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
    return ((const __nv_fp8_e5m2*)(&data))[i];
  }
  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
                                      const __nv_fp8_e5m2* src);
};

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
  data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
             (__nv_fp8x4_storage_t(val.__x) << 16) |
             (__nv_fp8x4_storage_t(val.__x) << 8) |
             __nv_fp8x4_storage_t(val.__x);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) {
  data = *((__nv_fp8x4_e5m2*)ptr);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(
    __nv_fp8_e5m2* ptr) const {
  *((__nv_fp8x4_e5m2*)ptr) = data;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(
    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
  *((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src);
}

// __nv_fp8_e5m2 x 8
template <>
struct vec_t<__nv_fp8_e5m2, 8> {
  uint2 data;

  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
    return ((__nv_fp8_e5m2*)(&data))[i];
  }
  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
    return ((const __nv_fp8_e5m2*)(&data))[i];
  }
  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 8>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
                                      const __nv_fp8_e5m2* src);
};

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
  ((__nv_fp8x4_e5m2*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
                                       __nv_fp8x4_storage_t(val.__x);
  ((__nv_fp8x4_e5m2*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
                                       __nv_fp8x4_storage_t(val.__x);
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) {
  data = *((uint2*)ptr);
}
APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(
    __nv_fp8_e5m2* ptr) const {
  *((uint2*)ptr) = data;
}

APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(
    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
  *((uint2*)dst) = *((uint2*)src);
}

// __nv_fp8_e5m2 x 16 or more
template <size_t vec_size>
struct vec_t<__nv_fp8_e5m2, vec_size> {
  uint4 data[vec_size / 16];

  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
    return ((__nv_fp8_e5m2*)data)[i];
  }
  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
    return ((const __nv_fp8_e5m2*)data)[i];
  }
  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
  }
  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 16; ++i) {
      ((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x =
          (__nv_fp8x4_storage_t(val.__x) << 24) |
          (__nv_fp8x4_storage_t(val.__x) << 16) |
          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
      ((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x =
          (__nv_fp8x4_storage_t(val.__x) << 24) |
          (__nv_fp8x4_storage_t(val.__x) << 16) |
          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
      ((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x =
          (__nv_fp8x4_storage_t(val.__x) << 24) |
          (__nv_fp8x4_storage_t(val.__x) << 16) |
          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
      ((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x =
          (__nv_fp8x4_storage_t(val.__x) << 24) |
          (__nv_fp8x4_storage_t(val.__x) << 16) |
          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
    }
  }
  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 16; ++i) {
      data[i] = ((uint4*)ptr)[i];
    }
  }
  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 16; ++i) {
      ((uint4*)ptr)[i] = data[i];
    }
  }
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
                                      const __nv_fp8_e5m2* src) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 16; ++i) {
      ((uint4*)dst)[i] = ((uint4*)src)[i];
    }
  }
};
/******************* vec_t<half> *******************/

// half x 1
template <>
struct vec_t<half, 1> {
  half data;

  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
  APHRODITE_INLINE const half& operator[](size_t i) const {
    return ((const half*)(&data))[i];
  }
  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
  APHRODITE_INLINE void fill(half val);
  APHRODITE_INLINE void load(const half* ptr);
  APHRODITE_INLINE void store(half* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
};

APHRODITE_INLINE void vec_t<half, 1>::fill(half val) { data = val; }

APHRODITE_INLINE void vec_t<half, 1>::load(const half* ptr) { data = *ptr; }

APHRODITE_INLINE void vec_t<half, 1>::store(half* ptr) const { *ptr = data; }

APHRODITE_INLINE void vec_t<half, 1>::memcpy(half* dst, const half* src) {
  *dst = *src;
}

// half x 2
template <>
struct vec_t<half, 2> {
  half2 data;

  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
  APHRODITE_INLINE const half& operator[](size_t i) const {
    return ((const half*)(&data))[i];
  }
  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
  APHRODITE_INLINE void fill(half val);
  APHRODITE_INLINE void load(const half* ptr);
  APHRODITE_INLINE void store(half* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
};

APHRODITE_INLINE void vec_t<half, 2>::fill(half val) {
  data = make_half2(val, val);
}

APHRODITE_INLINE void vec_t<half, 2>::load(const half* ptr) {
  data = *((half2*)ptr);
}

APHRODITE_INLINE void vec_t<half, 2>::store(half* ptr) const {
  *((half2*)ptr) = data;
}

APHRODITE_INLINE void vec_t<half, 2>::memcpy(half* dst, const half* src) {
  *((half2*)dst) = *((half2*)src);
}

// half x 4
template <>
struct vec_t<half, 4> {
  uint2 data;

  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
  APHRODITE_INLINE const half& operator[](size_t i) const {
    return ((const half*)(&data))[i];
  }
  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
  APHRODITE_INLINE void fill(half val);
  APHRODITE_INLINE void load(const half* ptr);
  APHRODITE_INLINE void store(half* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
};

APHRODITE_INLINE void vec_t<half, 4>::fill(half val) {
  *(half2*)(&data.x) = make_half2(val, val);
  *(half2*)(&data.y) = make_half2(val, val);
}

APHRODITE_INLINE void vec_t<half, 4>::load(const half* ptr) {
  data = *((uint2*)ptr);
}

APHRODITE_INLINE void vec_t<half, 4>::store(half* ptr) const {
  *((uint2*)ptr) = data;
}

APHRODITE_INLINE void vec_t<half, 4>::memcpy(half* dst, const half* src) {
  *((uint2*)dst) = *((uint2*)src);
}

// half x 8 or more
template <size_t vec_size>
struct vec_t<half, vec_size> {
  // uint4 chunks keep the vector 16-byte aligned, so each chunk moves with a
  // single 128-bit load/store instruction.
  uint4 data[vec_size / 8];

  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)data)[i]; }
  APHRODITE_INLINE const half& operator[](size_t i) const {
    return ((const half*)data)[i];
  }
  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
  APHRODITE_INLINE void fill(half val) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      *(half2*)(&(data[i].x)) = make_half2(val, val);
      *(half2*)(&(data[i].y)) = make_half2(val, val);
      *(half2*)(&(data[i].z)) = make_half2(val, val);
      *(half2*)(&(data[i].w)) = make_half2(val, val);
    }
  }
  APHRODITE_INLINE void load(const half* ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      data[i] = ((uint4*)ptr)[i];
    }
  }
  APHRODITE_INLINE void store(half* ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      ((uint4*)ptr)[i] = data[i];
    }
  }
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(half* dst, const half* src) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      ((uint4*)dst)[i] = ((uint4*)src)[i];
    }
  }
};
/******************* vec_t<nv_bfloat16> *******************/

// nv_bfloat16 x 1
template <>
struct vec_t<nv_bfloat16, 1> {
  nv_bfloat16 data;

  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
    return ((nv_bfloat16*)(&data))[i];
  }
  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
    return ((const nv_bfloat16*)(&data))[i];
  }
  APHRODITE_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  APHRODITE_INLINE void fill(nv_bfloat16 val);
  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst,
                                      const nv_bfloat16* src);
};

APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) {
  data = val;
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
  data = *ptr;
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
  *ptr = data;
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16* dst,
                                                    const nv_bfloat16* src) {
  *dst = *src;
}

// nv_bfloat16 x 2
template <>
struct vec_t<nv_bfloat16, 2> {
  nv_bfloat162 data;

  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
    return ((nv_bfloat16*)(&data))[i];
  }
  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
    return ((const nv_bfloat16*)(&data))[i];
  }
  APHRODITE_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  APHRODITE_INLINE void fill(nv_bfloat16 val);
  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst,
                                      const nv_bfloat16* src);
};

APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
  data = make_bfloat162(val, val);
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
  data = *((nv_bfloat162*)ptr);
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
  *((nv_bfloat162*)ptr) = data;
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16* dst,
                                                    const nv_bfloat16* src) {
  *((nv_bfloat162*)dst) = *((nv_bfloat162*)src);
}

// nv_bfloat16 x 4
template <>
struct vec_t<nv_bfloat16, 4> {
  uint2 data;

  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
    return ((nv_bfloat16*)(&data))[i];
  }
  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
    return ((const nv_bfloat16*)(&data))[i];
  }
  APHRODITE_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  APHRODITE_INLINE void fill(nv_bfloat16 val);
  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst,
                                      const nv_bfloat16* src);
};

APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
  *(nv_bfloat162*)(&data.x) = make_bfloat162(val, val);
  *(nv_bfloat162*)(&data.y) = make_bfloat162(val, val);
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
  data = *((uint2*)ptr);
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
  *((uint2*)ptr) = data;
}

APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16* dst,
                                                    const nv_bfloat16* src) {
  *((uint2*)dst) = *((uint2*)src);
}
// nv_bfloat16 x 8 or more
template <size_t vec_size>
struct vec_t<nv_bfloat16, vec_size> {
  uint4 data[vec_size / 8];

  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
    return ((nv_bfloat16*)data)[i];
  }
  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
    return ((const nv_bfloat16*)data)[i];
  }
  APHRODITE_INLINE nv_bfloat16* ptr() {
    return reinterpret_cast<nv_bfloat16*>(&data);
  }
  APHRODITE_INLINE void fill(nv_bfloat16 val) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      *(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val);
      *(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val);
      *(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val);
      *(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val);
    }
  }
  APHRODITE_INLINE void load(const nv_bfloat16* ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      data[i] = ((uint4*)ptr)[i];
    }
  }
  APHRODITE_INLINE void store(nv_bfloat16* ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      ((uint4*)ptr)[i] = data[i];
    }
  }
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst,
                                      const nv_bfloat16* src) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 8; ++i) {
      ((uint4*)dst)[i] = ((uint4*)src)[i];
    }
  }
};

/******************* vec_t<float> *******************/

// float x 1
template <>
struct vec_t<float, 1> {
  float data;

  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
  APHRODITE_INLINE const float& operator[](size_t i) const {
    return ((const float*)(&data))[i];
  }
  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
  APHRODITE_INLINE void fill(float val);
  APHRODITE_INLINE void load(const float* ptr);
  APHRODITE_INLINE void store(float* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(float* dst, const float* src);
};

APHRODITE_INLINE void vec_t<float, 1>::fill(float val) { data = val; }

APHRODITE_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }

APHRODITE_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }

APHRODITE_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) {
  *dst = *src;
}

// float x 2
template <>
struct vec_t<float, 2> {
  float2 data;

  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
  APHRODITE_INLINE const float& operator[](size_t i) const {
    return ((const float*)(&data))[i];
  }
  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
  APHRODITE_INLINE void fill(float val);
  APHRODITE_INLINE void load(const float* ptr);
  APHRODITE_INLINE void store(float* ptr) const;
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(float* dst, const float* src);
};

APHRODITE_INLINE void vec_t<float, 2>::fill(float val) {
  data = make_float2(val, val);
}

APHRODITE_INLINE void vec_t<float, 2>::load(const float* ptr) {
  data = *((float2*)ptr);
}

APHRODITE_INLINE void vec_t<float, 2>::store(float* ptr) const {
  *((float2*)ptr) = data;
}

APHRODITE_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
  *((float2*)dst) = *((float2*)src);
}
// float x 4 or more
template <size_t vec_size>
struct vec_t<float, vec_size> {
  float4 data[vec_size / 4];

  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
  APHRODITE_INLINE const float& operator[](size_t i) const {
    return ((const float*)(data))[i];
  }
  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
  APHRODITE_INLINE void fill(float val) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      data[i] = make_float4(val, val, val, val);
    }
  }
  APHRODITE_INLINE void load(const float* ptr) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      data[i] = ((float4*)ptr)[i];
    }
  }
  APHRODITE_INLINE void store(float* ptr) const {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      ((float4*)ptr)[i] = data[i];
    }
  }
  template <typename T>
  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
    cast_from_impl(*this, src);
  }
  template <typename T>
  APHRODITE_INLINE void cast_load(const T* ptr) {
    cast_load_impl(*this, ptr);
  }
  template <typename T>
  APHRODITE_INLINE void cast_store(T* ptr) const {
    cast_store_impl(ptr, *this);
  }
  APHRODITE_INLINE static void memcpy(float* dst, const float* src) {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      ((float4*)dst)[i] = ((float4*)src)[i];
    }
  }
};

}  // namespace aphrodite

#endif  // VEC_DTYPES_CUH_