quant_utils.cuh 16 KB


  1. #pragma once
  2. #include "hip_float8.h"
  3. #include <hip/hip_fp16.h>
  4. #include <hip/hip_bf16.h>
  5. #include <hip/hip_bfloat16.h>
  6. #include "../../../attention/dtype_fp8.cuh"
  7. #include "../../../attention/dtype_float32.cuh"
  8. #include "../../../attention/dtype_bfloat16.cuh"
  9. namespace aphrodite
  10. {
  11. #ifdef USE_ROCM
  12. namespace fp8 {
  13. #ifdef ENABLE_FP8
  14. template <typename Tout, typename Tin>
  15. __inline__ __device__ Tout vec_conversion(const Tin& x)
  16. {
  17. return x;
  18. }
  19. template <typename Tout, typename Tin>
  20. __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
  21. {
  22. return x;
  23. }
  24. // fp8 -> half
  25. template <>
  26. __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
  27. {
  28. hip_fp8 f8{a, hip_fp8::from_bits()};
  29. __half_raw res;
  30. res.data = static_cast<float>(f8);
  31. return res.x;
  32. }
  33. // fp8x2 -> half2
  34. template <>
  35. __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
  36. {
  37. #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  38. const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  39. union {
  40. __half2_raw h2r;
  41. uint32_t ui32;
  42. } tmp;
  43. tmp.h2r.x.data = f2[0];
  44. tmp.h2r.y.data = f2[1];
  45. return tmp.ui32;
  46. #else
  47. union {
  48. uint16_t u16[2];
  49. uint32_t u32;
  50. } tmp;
  51. tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
  52. tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
  53. return tmp.u32;
  54. #endif
  55. }
  56. // fp8x4 -> half2x2
  57. template <>
  58. __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
  59. {
  60. union {
  61. uint2 u32x2;
  62. uint32_t u32[2];
  63. } tmp;
  64. tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
  65. tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
  66. return tmp.u32x2;
  67. }
  68. // fp8x8 -> half2x4
  69. template <>
  70. __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
  71. {
  72. union {
  73. uint4 u64x2;
  74. uint2 u64[2];
  75. } tmp;
  76. tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
  77. tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
  78. return tmp.u64x2;
  79. }
  80. using __nv_bfloat16 = __hip_bfloat16;
  81. // fp8 -> __nv_bfloat16
  82. template <>
  83. __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
  84. {
  85. hip_fp8 f8{a, hip_fp8::from_bits()};
  86. float f{f8};
  87. return __float2bfloat16(f);
  88. }
  89. using __nv_bfloat162 = __hip_bfloat162;
  90. // fp8x2 -> __nv_bfloat162
  91. template <>
  92. __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
  93. {
  94. __nv_bfloat162 res;
  95. res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
  96. res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
  97. return res;
  98. }
  99. // fp8x4 -> bf16_4_t
  100. template <>
  101. __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
  102. {
  103. bf16_4_t res;
  104. res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
  105. res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
  106. return res;
  107. }
  108. // fp8x8 -> bf16_8_t
  109. template <>
  110. __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
  111. {
  112. bf16_4_t tmp1, tmp2;
  113. tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
  114. tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
  115. bf16_8_t res;
  116. res.x = tmp1.x;
  117. res.y = tmp1.y;
  118. res.z = tmp2.x;
  119. res.w = tmp2.y;
  120. return res;
  121. }
  122. // fp8 -> float
  123. template <>
  124. __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
  125. {
  126. hip_fp8 fp8{a, hip_fp8::from_bits()};
  127. return static_cast<float>(fp8);
  128. }
  129. // fp8x2 -> float2
  130. template <>
  131. __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
  132. {
  133. #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  134. float2 res;
  135. const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  136. res.x = f2[0];
  137. res.y = f2[1];
  138. return res;
  139. #else
  140. float2 res;
  141. res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
  142. res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
  143. return res;
  144. #endif
  145. }
  146. // fp8x4 -> float4
  147. template <>
  148. __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
  149. {
  150. Float4_ res;
  151. res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
  152. res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
  153. return res;
  154. }
  155. // fp8x8 -> float8
  156. template <>
  157. __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
  158. {
  159. Float4_ tmp1, tmp2;
  160. tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
  161. tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
  162. Float8_ res;
  163. res.x = tmp1.x;
  164. res.y = tmp1.y;
  165. res.z = tmp2.x;
  166. res.w = tmp2.y;
  167. return res;
  168. }
  169. // half -> fp8
  170. template <>
  171. __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
  172. {
  173. __half_raw tmp;
  174. tmp.x = a;
  175. hip_fp8 f8{static_cast<float>(tmp.data)};
  176. return f8.data;
  177. }
  178. // bf16 -> fp8
  179. template <>
  180. __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
  181. {
  182. hip_fp8 res{__bfloat162float(a)};
  183. return res.data;
  184. }
  185. // float -> fp8
  186. template <>
  187. __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
  188. {
  189. hip_fp8 f8(a);
  190. return f8.data;
  191. }
  192. // fp8x4 -> float4
  193. template <>
  194. __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
  195. {
  196. Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
  197. float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  198. return res;
  199. }
  200. // float2 -> half2
  201. template <>
  202. __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
  203. {
  204. union {
  205. half2 float16;
  206. uint32_t uint32;
  207. };
  208. float16 = __float22half2_rn(a);
  209. return uint32;
  210. }
  211. // Float4 -> half2x2
  212. template <>
  213. __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
  214. {
  215. uint2 b;
  216. float2 val;
  217. val.x = a.x.x;
  218. val.y = a.x.y;
  219. b.x = vec_conversion<uint32_t, float2>(val);
  220. val.x = a.y.x;
  221. val.y = a.y.y;
  222. b.y = vec_conversion<uint32_t, float2>(val);
  223. return b;
  224. }
  225. // Float4 -> float4
  226. template <>
  227. __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
  228. {
  229. float4 b;
  230. b.x = a.x.x;
  231. b.y = a.x.y;
  232. b.z = a.y.x;
  233. b.w = a.y.y;
  234. return b;
  235. }
  236. // Float8 -> half2x4
  237. template <>
  238. __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
  239. {
  240. uint4 b;
  241. b.x = vec_conversion<uint32_t, float2>(a.x);
  242. b.y = vec_conversion<uint32_t, float2>(a.y);
  243. b.z = vec_conversion<uint32_t, float2>(a.z);
  244. b.w = vec_conversion<uint32_t, float2>(a.w);
  245. return b;
  246. }
  247. // float2 -> bfloat162
  248. template <>
  249. __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
  250. {
  251. __nv_bfloat162 b = __float22bfloat162_rn(a);
  252. return b;
  253. }
  254. // Float4 -> bfloat162x2
  255. template <>
  256. __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
  257. {
  258. bf16_4_t b;
  259. b.x = __float22bfloat162_rn(a.x);
  260. b.y = __float22bfloat162_rn(a.y);
  261. return b;
  262. }
  263. // Float8 -> bfloat162x4
  264. template <>
  265. __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
  266. {
  267. bf16_8_t b;
  268. b.x = __float22bfloat162_rn(a.x);
  269. b.y = __float22bfloat162_rn(a.y);
  270. b.z = __float22bfloat162_rn(a.z);
  271. b.w = __float22bfloat162_rn(a.w);
  272. return b;
  273. }
  /* Scaled and vectorized conversions, used to exchange data between the
     high-precision (HP) and low-precision (FP8) domains.
     Scale convention of this API:
       FP8_data = Quantize(HP_data / scale)
     so that
       Quantize(HP / scale) => FP8
       Dequant(FP8) * scale => HP
  */
  280. // fp8 -> half
  281. template <>
  282. __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
  283. {
  284. hip_fp8 f8{a, hip_fp8::from_bits()};
  285. __half_raw res;
  286. res.data = static_cast<float>(f8) * scale;
  287. return res.x;
  288. }
  289. // fp8x2 -> half2
  290. template <>
  291. __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
  292. {
  293. #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  294. const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  295. union {
  296. __half2_raw h2r;
  297. uint32_t ui32;
  298. } tmp;
  299. tmp.h2r.x.data = f2[0] * scale;
  300. tmp.h2r.y.data = f2[1] * scale;
  301. return tmp.ui32;
  302. #else
  303. union {
  304. uint16_t u16[2];
  305. uint32_t u32;
  306. } tmp;
  307. tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
  308. tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
  309. return tmp.u32;
  310. #endif
  311. }
  312. // fp8x4 -> half2x2
  313. template <>
  314. __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
  315. {
  316. union {
  317. uint2 u32x2;
  318. uint32_t u32[2];
  319. } tmp;
  320. tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
  321. tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
  322. return tmp.u32x2;
  323. }
  324. // fp8x8 -> half2x4
  325. template <>
  326. __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
  327. {
  328. union {
  329. uint4 u64x2;
  330. uint2 u64[2];
  331. } tmp;
  332. tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  333. tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  334. return tmp.u64x2;
  335. }
  336. using __nv_bfloat16 = __hip_bfloat16;
  337. // fp8 -> __nv_bfloat16
  338. template <>
  339. __inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale)
  340. {
  341. hip_fp8 f8{a, hip_fp8::from_bits()};
  342. float f{f8};
  343. return __float2bfloat16(f * scale);
  344. }
  345. using __nv_bfloat162 = __hip_bfloat162;
  346. // fp8x2 -> __nv_bfloat162
  347. template <>
  348. __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale)
  349. {
  350. __nv_bfloat162 res;
  351. res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
  352. res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
  353. return res;
  354. }
  355. // fp8x4 -> bf16_4_t
  356. template <>
  357. __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
  358. {
  359. bf16_4_t res;
  360. res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
  361. res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
  362. return res;
  363. }
  364. // fp8x8 -> bf16_8_t
  365. template <>
  366. __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
  367. {
  368. bf16_4_t tmp1, tmp2;
  369. tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  370. tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  371. bf16_8_t res;
  372. res.x = tmp1.x;
  373. res.y = tmp1.y;
  374. res.z = tmp2.x;
  375. res.w = tmp2.y;
  376. return res;
  377. }
  378. // fp8 -> float
  379. template <>
  380. __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
  381. {
  382. hip_fp8 fp8{a, hip_fp8::from_bits()};
  383. return static_cast<float>(fp8) * scale;
  384. }
  385. // fp8x2 -> float2
  386. template <>
  387. __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
  388. {
  389. #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  390. float2 res;
  391. const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  392. res.x = f2[0] * scale;
  393. res.y = f2[1] * scale;
  394. return res;
  395. #else
  396. float2 res;
  397. res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
  398. res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
  399. return res;
  400. #endif
  401. }
  402. // fp8x4 -> float4
  403. template <>
  404. __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
  405. {
  406. Float4_ res;
  407. res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
  408. res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
  409. return res;
  410. }
  411. // fp8x8 -> float8
  412. template <>
  413. __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
  414. {
  415. Float4_ tmp1, tmp2;
  416. tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  417. tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
  418. Float8_ res;
  419. res.x = tmp1.x;
  420. res.y = tmp1.y;
  421. res.z = tmp2.x;
  422. res.w = tmp2.y;
  423. return res;
  424. }
  425. /* Quantize(HP / scale) => FP8 */
  426. // TODO: vectorized to add
  427. // half -> fp8
  428. template <>
  429. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
  430. {
  431. __half_raw tmp;
  432. tmp.x = a;
  433. hip_fp8 f8{static_cast<float>(tmp.data)/scale};
  434. return f8.data;
  435. }
  436. // bf16 -> fp8
  437. template <>
  438. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale)
  439. {
  440. hip_fp8 res{__bfloat162float(a)/scale};
  441. return res.data;
  442. }
  443. // float -> fp8
  444. template <>
  445. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
  446. {
  447. hip_fp8 f8(a/scale);
  448. return f8.data;
  449. }
  450. // fp8x4 -> float4
  451. template <>
  452. __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
  453. {
  454. Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  455. float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  456. return res;
  457. }
  458. #endif // ENABLE_FP8
  459. template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
  460. __inline__ __device__ Tout convert(const Tin &x) {
  461. #ifdef ENABLE_FP8
  462. if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
  463. return vec_conversion<Tout, Tin>(x);
  464. }
  465. #endif
  466. assert(false);
  467. return {}; // Squash missing return statement warning
  468. }
  469. template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
  470. __inline__ __device__ Tout scaled_convert(const Tin &x, const float scale) {
  471. #ifdef ENABLE_FP8
  472. if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
  473. return scaled_vec_conversion<Tout, Tin>(x, scale);
  474. }
  475. #endif
  476. assert(false);
  477. return {}; // Squash missing return statement warning
  478. }
  // Dispatches FN on the source dtype and the KV-cache dtype string. FN is a
  // macro that calls a function with
  //   template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
  //   KV_DTYPE == "auto"             -> cache_t = scalar_t, kv_dt = kAuto
  //   KV_DTYPE == "fp8" / "fp8_e4m3" -> cache_t = uint8_t,  kv_dt = kFp8E4M3
  // SRC_DTYPE maps Float -> float, Half -> uint16_t, BFloat16 -> __nv_bfloat16.
  // Any other SRC_DTYPE or KV_DTYPE fails a TORCH_CHECK.
  // NOTE: macros are not namespace-scoped, so every name in the expansion
  // (e.g. __nv_bfloat16) is resolved at the call site. The expansion is a bare
  // if/else chain with no do/while(0) wrapper.
  #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                      \
    if (KV_DTYPE == "auto") {                                                      \
      if (SRC_DTYPE == at::ScalarType::Float) {                                    \
        FN(float, float, aphrodite::Fp8KVCacheDataType::kAuto);                    \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                              \
        FN(uint16_t, uint16_t, aphrodite::Fp8KVCacheDataType::kAuto);              \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                          \
        FN(__nv_bfloat16, __nv_bfloat16, aphrodite::Fp8KVCacheDataType::kAuto);   \
      } else {                                                                     \
        TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);     \
      }                                                                            \
    } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                      \
      if (SRC_DTYPE == at::ScalarType::Float) {                                    \
        FN(float, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);               \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                              \
        FN(uint16_t, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);            \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                          \
        FN(__nv_bfloat16, uint8_t, aphrodite::Fp8KVCacheDataType::kFp8E4M3);       \
      } else {                                                                     \
        TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);     \
      }                                                                            \
    } else {                                                                       \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);         \
    }
  508. } // fp8
  509. #endif // USE_ROCM
  510. } // namespace aphrodite