// quant_utils.cuh — FP8 (de)quantization helpers for the KV cache.
  1. #pragma once
  2. #include "../../../attention/attention_dtypes.h"
  3. #include <assert.h>
  4. #include <float.h>
  5. #include <stdint.h>
  6. #include <type_traits>
  7. namespace aphrodite {
  8. #ifndef USE_ROCM
  9. namespace fp8 {
  10. #ifdef ENABLE_FP8
  11. #if 0 // Disable the following code to reduce the binary size.
  12. template <typename Tout, typename Tin>
  13. __inline__ __device__ Tout
  14. vec_conversion(const Tin &x, const __nv_fp8_interpretation_t fp8_type) {
  15. return x;
  16. }
  17. // fp8 -> half
  18. template <>
  19. __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
  20. const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
  21. __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  22. return res.x;
  23. }
  24. // fp8x2 -> half2
  25. template <>
  26. __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
  27. const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  28. union {
  29. uint16_t u16[2];
  30. uint32_t u32;
  31. } tmp;
  32. __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  33. tmp.u16[0] = res.x;
  34. tmp.u16[1] = res.y;
  35. return tmp.u32;
  36. }
  37. // fp8x4 -> half2x2
  38. template <>
  39. __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
  40. const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  41. union {
  42. uint2 u32x2;
  43. uint32_t u32[2];
  44. } tmp;
  45. tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a, fp8_type);
  46. tmp.u32[1] =
  47. vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  48. return tmp.u32x2;
  49. }
  50. // fp8x8 -> half2x4
  51. template <>
  52. __inline__ __device__ uint4 vec_conversion<uint4, uint2>(
  53. const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  54. union {
  55. uint4 u64x2;
  56. uint2 u64[2];
  57. } tmp;
  58. tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
  59. tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
  60. return tmp.u64x2;
  61. }
  62. // fp8 -> __nv_bfloat16
  63. template <>
  64. __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
  65. const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
  66. // Note there is no direct convert function from fp8 to bf16.
  67. // fp8 -> half
  68. __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  69. // half -> float -> bf16
  70. float tmp = half_to_float(res.x);
  71. return __float2bfloat16(tmp);
  72. }
  73. // fp8x2 -> __nv_bfloat162
  74. template <>
  75. __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
  76. const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  77. __nv_bfloat162 res;
  78. res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, fp8_type);
  79. res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), fp8_type);
  80. return res;
  81. }
  82. // fp8x4 -> bf16_4_t
  83. template <>
  84. __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
  85. const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  86. bf16_4_t res;
  87. res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, fp8_type);
  88. res.y =
  89. vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  90. return res;
  91. }
  92. // fp8x8 -> bf16_8_t
  93. template <>
  94. __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
  95. const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  96. bf16_4_t tmp1, tmp2;
  97. tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
  98. tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
  99. bf16_8_t res;
  100. res.x = tmp1.x;
  101. res.y = tmp1.y;
  102. res.z = tmp2.x;
  103. res.w = tmp2.y;
  104. return res;
  105. }
  106. // fp8 -> float
  107. template <>
  108. __inline__ __device__ float
  109. vec_conversion<float, uint8_t>(const uint8_t &a,
  110. const __nv_fp8_interpretation_t fp8_type) {
  111. // fp8 -> half
  112. uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a, fp8_type);
  113. // half -> float
  114. return half_to_float(tmp);
  115. }
  116. // fp8x2 -> float2
  117. template <>
  118. __inline__ __device__ float2 vec_conversion<float2, uint16_t>(
  119. const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  120. // fp8x2 -> half2
  121. uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a, fp8_type);
  122. // half2 -> float2
  123. return half2_to_float2(tmp);
  124. }
  125. // fp8x4 -> float4
  126. template <>
  127. __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
  128. const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  129. Float4_ res;
  130. res.x = vec_conversion<float2, uint16_t>((uint16_t)a, fp8_type);
  131. res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), fp8_type);
  132. return res;
  133. }
  134. // fp8x8 -> float8
  135. template <>
  136. __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
  137. const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  138. Float4_ tmp1, tmp2;
  139. tmp1 = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
  140. tmp2 = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
  141. Float8_ res;
  142. res.x = tmp1.x;
  143. res.y = tmp1.y;
  144. res.z = tmp2.x;
  145. res.w = tmp2.y;
  146. return res;
  147. }
  148. // half -> fp8
  149. template <>
  150. __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
  151. const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  152. __half_raw tmp;
  153. tmp.x = a;
  154. __nv_fp8_storage_t res =
  155. __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, fp8_type);
  156. return (uint8_t)res;
  157. }
  158. // bf16 -> fp8
  159. template <>
  160. __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
  161. const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
  162. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  163. assert(false);
  164. #else
  165. __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
  166. __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
  167. return (uint8_t)res;
  168. #endif
  169. }
  170. // float -> fp8
  171. template <>
  172. __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
  173. const float &a, const __nv_fp8_interpretation_t fp8_type) {
  174. __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type);
  175. return (uint8_t)res;
  176. }
  177. // fp8x4 -> float4
  178. template <>
  179. __inline__ __device__ float4 vec_conversion<float4, uint32_t>(
  180. const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  181. Float4_ tmp = vec_conversion<Float4_, uint32_t>(a, fp8_type);
  182. float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  183. return res;
  184. }
  185. template <>
  186. __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
  187. const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
  188. union {
  189. half2 float16;
  190. uint32_t uint32;
  191. };
  192. float16 = __float22half2_rn(a);
  193. return uint32;
  194. }
  195. template <>
  196. __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
  197. const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  198. uint2 b;
  199. float2 val;
  200. val.x = a.x.x;
  201. val.y = a.x.y;
  202. b.x = vec_conversion<uint32_t, float2>(val, fp8_type);
  203. val.x = a.y.x;
  204. val.y = a.y.y;
  205. b.y = vec_conversion<uint32_t, float2>(val, fp8_type);
  206. return b;
  207. }
  208. template <>
  209. __inline__ __device__ float4 vec_conversion<float4, Float4_>(
  210. const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  211. float4 b;
  212. b.x = a.x.x;
  213. b.y = a.x.y;
  214. b.z = a.y.x;
  215. b.w = a.y.y;
  216. return b;
  217. }
  218. template <>
  219. __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
  220. const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
  221. uint4 b;
  222. b.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
  223. b.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
  224. b.z = vec_conversion<uint32_t, float2>(a.z, fp8_type);
  225. b.w = vec_conversion<uint32_t, float2>(a.w, fp8_type);
  226. return b;
  227. }
  228. template <>
  229. __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
  230. const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
  231. __nv_bfloat162 b;
  232. from_float(b, a);
  233. return b;
  234. }
  235. template <>
  236. __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
  237. const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  238. bf16_4_t b;
  239. from_float(b, a);
  240. return b;
  241. }
  242. template <>
  243. __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
  244. const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
  245. bf16_8_t b;
  246. from_float(b, a);
  247. return b;
  248. }
  249. #endif
/* Scaled and vectorized conversions, for data exchange between the high- and
 * low-precision domains.
 *
 * Scale convention in this API:
 *   FP8_data = Quantize(High_Precision_data / scale)
 * i.e. Quantize(HP / scale) => FP8, and Dequantize(FP8) * scale => HP.
 */
  255. template <typename Tout, typename Tin>
  256. __inline__ __device__ Tout scaled_vec_conversion(
  257. const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
  258. return x;
  259. }
  260. // fp8 -> half
  261. template <>
  262. __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
  263. const uint8_t& a, const float scale,
  264. const __nv_fp8_interpretation_t fp8_type) {
  265. __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  266. return float_to_half(half_to_float(tmp.x) * scale);
  267. }
  268. // fp8x2 -> half2
  269. template <>
  270. __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
  271. const uint16_t& a, const float scale,
  272. const __nv_fp8_interpretation_t fp8_type) {
  273. union {
  274. uint16_t u16[2];
  275. uint32_t u32;
  276. } tmp;
  277. __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  278. tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
  279. tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
  280. return tmp.u32;
  281. }
  282. // fp8x4 -> half2x2
  283. template <>
  284. __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
  285. const uint32_t& a, const float scale,
  286. const __nv_fp8_interpretation_t fp8_type) {
  287. union {
  288. uint2 u32x2;
  289. uint32_t u32[2];
  290. } tmp;
  291. tmp.u32[0] =
  292. scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
  293. tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
  294. scale, fp8_type);
  295. return tmp.u32x2;
  296. }
  297. // fp8x8 -> half2x4
  298. template <>
  299. __inline__ __device__ uint4
  300. scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
  301. const __nv_fp8_interpretation_t fp8_type) {
  302. union {
  303. uint4 u64x2;
  304. uint2 u64[2];
  305. } tmp;
  306. tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
  307. tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
  308. return tmp.u64x2;
  309. }
  310. // fp8 -> __nv_bfloat16
  311. template <>
  312. __inline__ __device__ __nv_bfloat16
  313. scaled_vec_conversion<__nv_bfloat16, uint8_t>(
  314. const uint8_t& a, const float scale,
  315. const __nv_fp8_interpretation_t fp8_type) {
  316. // Note there is no direct convert function from fp8 to bf16.
  317. // fp8 -> half
  318. __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  319. // half -> float -> bf16
  320. float tmp = half_to_float(res.x);
  321. return __float2bfloat16(tmp * scale);
  322. }
  323. // fp8x2 -> __nv_bfloat162
  324. template <>
  325. __inline__ __device__ __nv_bfloat162
  326. scaled_vec_conversion<__nv_bfloat162, uint16_t>(
  327. const uint16_t& a, const float scale,
  328. const __nv_fp8_interpretation_t fp8_type) {
  329. __nv_bfloat162 res;
  330. res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
  331. fp8_type);
  332. res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
  333. scale, fp8_type);
  334. return res;
  335. }
  336. // fp8x4 -> bf16_4_t
  337. template <>
  338. __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
  339. const uint32_t& a, const float scale,
  340. const __nv_fp8_interpretation_t fp8_type) {
  341. bf16_4_t res;
  342. res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
  343. fp8_type);
  344. res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
  345. scale, fp8_type);
  346. return res;
  347. }
  348. // fp8x8 -> bf16_8_t
  349. template <>
  350. __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
  351. const uint2& a, const float scale,
  352. const __nv_fp8_interpretation_t fp8_type) {
  353. bf16_4_t tmp1, tmp2;
  354. tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
  355. tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
  356. bf16_8_t res;
  357. res.x = tmp1.x;
  358. res.y = tmp1.y;
  359. res.z = tmp2.x;
  360. res.w = tmp2.y;
  361. return res;
  362. }
  363. // fp8 -> float
  364. template <>
  365. __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
  366. const uint8_t& a, const float scale,
  367. const __nv_fp8_interpretation_t fp8_type) {
  368. // fp8 -> half
  369. __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  370. uint16_t tmp = res.x;
  371. // half -> float
  372. return half_to_float(tmp) * scale;
  373. }
  374. // fp8x2 -> float2
  375. template <>
  376. __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
  377. const uint16_t& a, const float scale,
  378. const __nv_fp8_interpretation_t fp8_type) {
  379. // fp8x2 -> half2
  380. uint32_t tmp = scaled_vec_conversion<uint32_t, uint16_t>(a, scale, fp8_type);
  381. // half2 -> float2
  382. return half2_to_float2(tmp);
  383. }
  384. // fp8x4 -> float4
  385. template <>
  386. __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
  387. const uint32_t& a, const float scale,
  388. const __nv_fp8_interpretation_t fp8_type) {
  389. Float4_ res;
  390. res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
  391. res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
  392. fp8_type);
  393. return res;
  394. }
  395. // fp8x8 -> float8
  396. template <>
  397. __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
  398. const uint2& a, const float scale,
  399. const __nv_fp8_interpretation_t fp8_type) {
  400. Float4_ tmp1, tmp2;
  401. tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
  402. tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
  403. Float8_ res;
  404. res.x = tmp1.x;
  405. res.y = tmp1.y;
  406. res.z = tmp2.x;
  407. res.w = tmp2.y;
  408. return res;
  409. }
  410. // half -> fp8
  411. template <>
  412. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
  413. const uint16_t& a, const float scale,
  414. const __nv_fp8_interpretation_t fp8_type) {
  415. __nv_fp8_storage_t res =
  416. __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
  417. return (uint8_t)res;
  418. }
  419. // bf16 -> fp8
  420. template <>
  421. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
  422. const __nv_bfloat16& a, const float scale,
  423. const __nv_fp8_interpretation_t fp8_type) {
  424. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  425. assert(false);
  426. #else
  427. __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
  428. __NV_SATFINITE, fp8_type);
  429. return (uint8_t)res;
  430. #endif
  431. }
  432. // float -> fp8
  433. template <>
  434. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
  435. const float& a, const float scale,
  436. const __nv_fp8_interpretation_t fp8_type) {
  437. __nv_fp8_storage_t res =
  438. __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
  439. return (uint8_t)res;
  440. }
  441. // fp8x4 -> float4
  442. template <>
  443. __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
  444. const uint32_t& a, const float scale,
  445. const __nv_fp8_interpretation_t fp8_type) {
  446. Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  447. float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  448. return res;
  449. }
  450. #endif // ENABLE_FP8
  451. template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
  452. __inline__ __device__ Tout convert(const Tin& x) {
  453. #if 0 // Disable the following code to reduce the binary size.
  454. if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
  455. return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  456. } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
  457. return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  458. }
  459. #endif
  460. assert(false);
  461. return {}; // Squash missing return statement warning
  462. }
  463. template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
  464. __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  465. #ifdef ENABLE_FP8
  466. if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
  467. return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  468. } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
  469. return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  470. }
  471. #endif
  472. assert(false);
  473. return {}; // Squash missing return statement warning
  474. }
  475. // The following macro is used to dispatch the conversion function based on
  476. // the data type of the key and value cache. The FN is a macro that calls a
  477. // function with template<typename scalar_t, typename cache_t,
  478. // Fp8KVCacheDataType kv_dt>.
  479. #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \
  480. if (KV_DTYPE == "auto") { \
  481. if (SRC_DTYPE == at::ScalarType::Float) { \
  482. FN(float, float, aphrodite::Fp8KVCacheDataType::kAuto); \
  483. } else if (SRC_DTYPE == at::ScalarType::Half) { \
  484. FN(uint16_t, uint16_t, aphrodite::Fp8KVCacheDataType::kAuto); \
  485. } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
  486. FN(__nv_bfloat16, __nv_bfloat16, \
  487. aphrodite::Fp8KVCacheDataType::kAuto); \
  488. } else { \
  489. TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
  490. } \
  491. } else { \
  492. if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \
  493. if (SRC_DTYPE == at::ScalarType::Float) { \
  494. FN(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3); \
  495. } else if (SRC_DTYPE == at::ScalarType::Half) { \
  496. FN(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3); \
  497. } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
  498. FN(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3); \
  499. } else { \
  500. TORCH_CHECK(false, \
  501. "Unsupported input type of kv cache: ", SRC_DTYPE); \
  502. } \
  503. } else if (KV_DTYPE == "fp8_e5m2") { \
  504. if (SRC_DTYPE == at::ScalarType::Float) { \
  505. FN(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2); \
  506. } else if (SRC_DTYPE == at::ScalarType::Half) { \
  507. FN(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2); \
  508. } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \
  509. FN(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E5M2); \
  510. } else { \
  511. TORCH_CHECK(false, \
  512. "Unsupported input type of kv cache: ", SRC_DTYPE); \
  513. } \
  514. } else { \
  515. TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \
  516. } \
  517. }
  518. } // namespace fp8
  519. #endif // not USE_ROCM
  520. } // namespace aphrodite