#pragma once

#include "../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
#include <stdint.h>
#include <type_traits>

namespace aphrodite {
#ifndef USE_ROCM
namespace fp8 {
#ifdef ENABLE_FP8

#if 0  // Disable the following code to reduce the binary size.
template <typename Tout, typename Tin>
__inline__ __device__ Tout
vec_conversion(const Tin& x, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return res.x;
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = res.x;
  tmp.u16[1] = res.y;
  return tmp.u32;
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
  tmp.u32[1] =
      vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
  tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp);
}

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
  res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
  res.y =
      vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
  tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(
    const uint8_t& a, const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
  // half -> float
  return half_to_float(tmp);
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2
  uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
  // half2 -> float2
  return half2_to_float2(tmp);
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
    const uint2& a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp;
  tmp.x = a;
  __nv_fp8_storage_t res =
      __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
      __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
    const float& a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
    const uint32_t& a, const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}

// The remaining specializations convert between high-precision vector types
// (no fp8 involved); the fp8_type argument is unused.

// float2 -> half2
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
    const float2& a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    half2 float16;
    uint32_t uint32;
  };
  float16 = __float22half2_rn(a);
  return uint32;
}

// Float4_ -> half2x2
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
  return b;
}

// Float4_ -> float4
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

// Float8_ -> half2x4
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
    const Float8_& a, const __nv_fp8_interpretation_t fp8_type) {
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
  b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
  b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
  b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
  return b;
}

// float2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
    const float2& a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 b;
  from_float(b, a);
  return b;
}

// Float4_ -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
    const Float4_& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t b;
  from_float(b, a);
  return b;
}

// Float8_ -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
    const Float8_& a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_8_t b;
  from_float(b, a);
  return b;
}
#endif

/* Scaled and vectorized conversions, for data exchange between the
   high-precision and low-precision (fp8) domains.

   Scale convention used throughout this API:
     FP8_data = Quantize(High_Precision_data / scale)
   i.e.  Quantize(HP / scale)  => FP8
         Dequant(FP8) * scale  => HP
*/
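// Illustrative round trip under this convention (numbers chosen purely for
// exposition, ignoring fp8 rounding error):
//   scale = 0.5, hp = 3.0
//   store:  fp8 = Quantize(hp / scale) = Quantize(6.0)
//   load:   hp' = Dequant(fp8) * scale ~= 6.0 * 0.5 = 3.0
// The fp8 code therefore represents hp / scale, and multiplying the
// dequantized value by the same scale recovers the high-precision value.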

template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
    const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return float_to_half(half_to_float(tmp.x) * scale);
}

// fp8x2 -> half2
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
  tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
  return tmp.u32;
}

// fp8x4 -> half2x2
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] =
      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
  tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
                                                         scale, fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // Note there is no direct convert function from fp8 to bf16.
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp * scale);
}

// fp8x2 -> __nv_bfloat162
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
                                                        fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
                                                        scale, fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
                                                          fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale, fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  uint16_t tmp = res.x;
  // half -> float
  return half_to_float(tmp) * scale;
}

// fp8x2 -> float2
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8x2 -> half2
  uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
  // half2 -> float2
  return half2_to_float2(tmp);
}

// fp8x4 -> float4
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
                                                  fp8_type);
  return res;
}

// fp8x8 -> float8
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
                                                 __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
#endif
}

// float -> fp8
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
    const float& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
#endif  // ENABLE_FP8

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
#if 0  // Disable the following code to reduce the binary size.
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  }
#endif
  assert(false);
  return {};  // Squash missing return statement warning
}

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
#ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  }
#endif
  assert(false);
  return {};  // Squash missing return statement warning
}
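
// Illustrative use of scaled_convert (a minimal sketch; `fp8_byte` and
// `kv_scale` are hypothetical variables, not defined in this header):
//
//   // Dequantize one fp8-e4m3 byte from the KV cache into a half (uint16_t).
//   uint16_t h = scaled_convert<uint16_t, uint8_t,
//                               Fp8KVCacheDataType::kFp8E4M3>(fp8_byte, kv_scale);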

// The following macro is used to dispatch the conversion function based on
// the data type of the key and value cache. FN is a macro that invokes a
// function templated on <typename scalar_t, typename cache_t,
// Fp8KVCacheDataType kv_dt>; an illustrative usage sketch follows the macro
// definition.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \
  if (KV_DTYPE == "auto") {                                                  \
    if (SRC_DTYPE == at::ScalarType::Float) {                                \
      FN(float, float, aphrodite::Fp8KVCacheDataType::kAuto);                \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
      FN(uint16_t, uint16_t, aphrodite::Fp8KVCacheDataType::kAuto);          \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
      FN(__nv_bfloat16, __nv_bfloat16,                                       \
         aphrodite::Fp8KVCacheDataType::kAuto);                              \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
    }                                                                        \
  } else {                                                                   \
    if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                       \
      if (SRC_DTYPE == at::ScalarType::Float) {                              \
        FN(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);         \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
        FN(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);      \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
        FN(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3); \
      } else {                                                               \
        TORCH_CHECK(false,                                                   \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);      \
      }                                                                      \
    } else if (KV_DTYPE == "fp8_e5m2") {                                     \
      if (SRC_DTYPE == at::ScalarType::Float) {                              \
        FN(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2);         \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
        FN(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2);      \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
        FN(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2); \
      } else {                                                               \
        TORCH_CHECK(false,                                                   \
                    "Unsupported input type of kv cache: ", SRC_DTYPE);      \
      }                                                                      \
    } else {                                                                 \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);   \
    }                                                                        \
  }
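
// Illustrative expansion of DISPATCH_BY_KV_CACHE_DTYPE (a minimal sketch;
// CALL_MY_KERNEL, my_cache_kernel, grid, block, stream, key, and
// kv_cache_dtype are hypothetical names, shown only to clarify the expected
// shape of FN):
//
//   #define CALL_MY_KERNEL(SCALAR_T, CACHE_T, KV_DT)           \
//     my_cache_kernel<SCALAR_T, CACHE_T, KV_DT>                \
//         <<<grid, block, 0, stream>>>(/* kernel arguments */);
//
//   DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, CALL_MY_KERNEL);
//
// FN is expanded with the (scalar_t, cache_t, kv_dt) triple selected at
// runtime from SRC_DTYPE and KV_DTYPE.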

}  // namespace fp8
#endif  // not USE_ROCM
}  // namespace aphrodite