#pragma once

// FP8 <-> higher-precision conversion helpers for the fp8 KV cache on NVIDIA
// GPUs. Everything inside namespace fp8 is compiled out when USE_ROCM is
// defined; the scaled conversions additionally require ENABLE_FP8.
//
// Packing conventions used throughout (as established by the casts/shifts
// below):
//   * one fp8 value is a uint8_t; 2 / 4 / 8 packed fp8 values are a
//     uint16_t / uint32_t / uint2, with the lowest byte holding element 0;
//   * fp16 values are carried as raw bits: half in uint16_t, half2 in
//     uint32_t, 4 halves in uint2, 8 halves in uint4;
//   * Float4_ / Float8_ hold two / four float2 members (x, y[, z, w]);
//     bf16_4_t / bf16_8_t hold two / four __nv_bfloat162 members.
//
// NOTE(review): half_to_float, float_to_half, half2_to_float2, from_float,
// Float4_/Float8_ and bf16_4_t/bf16_8_t are not defined in this file;
// presumably they come from attention_dtypes.h -- confirm.

#include "../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
#include <stdint.h>
#include <type_traits>

namespace aphrodite {
#ifndef USE_ROCM

namespace fp8 {
#ifdef ENABLE_FP8

#if 0  // Disable the following code to reduce the binary size.
// Unscaled conversions. This whole section is compiled out and kept for
// reference only; nothing up to the matching "#endif  // #if 0" is available.

// Primary template: fallback for type pairs without a specialization. It
// relies on an implicit Tin -> Tout conversion and ignores fp8_type.
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(
    const Tin& x, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = res.x;
  tmp.u16[1] = res.y;
  return tmp.u32;
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
  tmp.u32[1] =
      vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
  tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp);
}

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
  res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
  res.y =
      vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
  tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
  // half -> float
  return half_to_float(tmp);
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2
  uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
  // half2 -> float2
  return half2_to_float2(tmp);
}

// fp8x4 -> float4 (as Float4_)
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return res;
}

// fp8x8 -> float8 (as Float8_)
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp;
  tmp.x = a;
  __nv_fp8_storage_t res =
      __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  // Never fall off the end of a non-void function (UB once NDEBUG removes
  // the assert).
  return {};
#else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
      __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
    const float& a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

// float2 -> half2 bits.
// NOTE(review): despite living in the fp8 family, this produces half2 (via
// __float22half2_rn), not fp8; fp8_type is unused.
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
    const float2& a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    half2 float16;
    uint32_t uint32;
  };
  float16 = __float22half2_rn(a);
  return uint32;
}

// Float4_ -> 4 halves packed in uint2 (built on the float2 -> half2 overload,
// so this also yields fp16, not fp8).
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val, fp8_type);

  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val, fp8_type);

  return b;
}

// Float4_ -> float4: plain repack; fp8_type is unused.
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

// Float8_ -> 8 halves packed in uint4 (fp16, not fp8; see float2 overload).
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
    const Float8_& a, const __nv_fp8_interpretation_t fp8_type) {
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
  b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
  b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
  b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
  return b;
}

// float2 -> __nv_bfloat162 via from_float; fp8_type is unused.
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
    const float2& a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 b;
  from_float(b, a);
  return b;
}

// Float4_ -> bf16_4_t via from_float; fp8_type is unused.
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t b;
  from_float(b, a);
  return b;
}

// Float8_ -> bf16_8_t via from_float; fp8_type is unused.
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
    const Float8_& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_8_t b;
  from_float(b, a);
  return b;
}
#endif  // #if 0

/* Scaled, vectorized conversions for exchanging data between the
   high-precision (HP) and low-precision (fp8) domains.

   Scale convention used by every specialization below:
     quantize:    FP8 = Quantize(HP / scale)
     dequantize:  HP  = Dequant(FP8) * scale

   Quantization uses __NV_SATFINITE saturation.
 */

// Primary template: fallback for type pairs without a specialization. It
// relies on an implicit Tin -> Tout conversion and ignores both scale and
// fp8_type.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
    const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half: widen to half, apply the scale in float, round once to half.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return float_to_half(half_to_float(tmp.x) * scale);
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
  tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
  return tmp.u32;
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] =
      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
  tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
                                                         scale, fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16, so go
  // fp8 -> half -> float, apply the scale in float, then round once to bf16.
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp * scale);
}

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
                                                        fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
                                                        scale, fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
                                                          fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale, fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float: widen via half, then apply the scale in float.
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  uint16_t tmp = res.x;
  // half -> float
  return half_to_float(tmp) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // Widen to half2 unscaled, then apply the scale in float exactly like the
  // scalar fp8 -> float specialization above. Scaling through the half2 path
  // would round the scaled product to half (losing precision and possibly
  // overflowing half's range) before widening it to float.
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  return make_float2(half_to_float(res.x) * scale,
                     half_to_float(res.y) * scale);
}

// fp8x4 -> float4 (as Float4_)
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
                                                  fp8_type);
  return res;
}

// fp8x8 -> float8 (as Float8_)
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8. Unsupported (asserts) when compiled for __CUDA_ARCH__ < 800.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
  // Never fall off the end of a non-void function: once NDEBUG removes the
  // assert that would be undefined behavior. Yields 0 instead.
  return {};
#else
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
                                                 __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
    const float& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
#endif  // ENABLE_FP8

// Unscaled conversion selected by the KV-cache dtype. Its implementation is
// compiled out (#if 0) to reduce binary size, so every instantiation reaches
// assert(false) and returns a value-initialized Tout.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#if 0  // Disable the following code to reduce the binary size.
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  }
#endif
  assert(false);
  return {};  // Squash missing return statement warning
}

// Scaled conversion selected by the KV-cache dtype: kFp8E4M3 / kFp8E5M2 map
// to the E4M3 / E5M2 interpretation of scaled_vec_conversion. Any other
// kv_dt, or a build without ENABLE_FP8, reaches assert(false) and returns a
// value-initialized Tout.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  }
#endif
  assert(false);
  return {};  // Squash missing return statement warning
}

// The following macro dispatches on the source tensor dtype (SRC_DTYPE,
// compared against at::ScalarType) and the KV-cache dtype string (KV_DTYPE)
// by invoking FN(scalar_t, cache_t, kv_dt):
//   "auto"             -> cache_t == scalar_t,  kv_dt = kAuto
//   "fp8" / "fp8_e4m3" -> cache_t = uint8_t,    kv_dt = kFp8E4M3
//   "fp8_e5m2"         -> cache_t = uint8_t,    kv_dt = kFp8E5M2
// Half tensors are passed as uint16_t. Unsupported combinations fail through
// TORCH_CHECK. FN is a macro that calls a function templated on
// <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
// NOTE: at::ScalarType and TORCH_CHECK must be visible at the expansion site.
// The expansion is a bare if/else chain (not wrapped in do/while), so do not
// use it as the sole body of an unbraced if that has an else.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                    \
  if (KV_DTYPE == "auto") {                                                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, float, aphrodite::Fp8KVCacheDataType::kAuto);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint16_t, aphrodite::Fp8KVCacheDataType::kAuto);            \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, __nv_bfloat16,                                         \
         aphrodite::Fp8KVCacheDataType::kAuto);                                \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else {                                                                     \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                         \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);           \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);        \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);   \
      } else {                                                                 \
        TORCH_CHECK(false,                                                     \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);        \
      }                                                                        \
    } else if (KV_DTYPE == "fp8_e5m2") {                                       \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2);           \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2);        \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2);   \
      } else {                                                                 \
        TORCH_CHECK(false,                                                     \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);        \
      }                                                                        \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);     \
    }                                                                          \
  }

}  // namespace fp8
#endif  // not USE_ROCM
}  // namespace aphrodite