quant_utils.cuh 13 KB


  1. #pragma once
  2. #include "hip_float8.h"
  3. #include <hip/hip_fp16.h>
  4. #include <hip/hip_bf16.h>
  5. #include <hip/hip_bfloat16.h>
  6. #include "../../../attention/dtype_float32.cuh"
  7. #include "../../../attention/dtype_bfloat16.cuh"
  8. namespace aphrodite
  9. {
  10. namespace fp8_e4m3 {
  11. template <typename Tout, typename Tin>
  12. __inline__ __device__ Tout vec_conversion(const Tin& x)
  13. {
  14. return x;
  15. }
  16. template <typename Tout, typename Tin>
  17. __inline__ __device__ Tout scaled_vec_conversion(const Tin& x, const float scale)
  18. {
  19. return x;
  20. }
  21. // fp8 -> half
  22. template <>
  23. __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
  24. {
  25. hip_fp8 f8{a, hip_fp8::from_bits()};
  26. __half_raw res;
  27. res.data = static_cast<float>(f8);
  28. return res.x;
  29. }
  30. // fp8x2 -> half2
  31. template <>
  32. __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
  33. {
  34. #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  35. const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  36. union {
  37. __half2_raw h2r;
  38. uint32_t ui32;
  39. } tmp;
  40. tmp.h2r.x.data = f2[0];
  41. tmp.h2r.y.data = f2[1];
  42. return tmp.ui32;
  43. #else
  44. union {
  45. uint16_t u16[2];
  46. uint32_t u32;
  47. } tmp;
  48. tmp.u16[0] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a));
  49. tmp.u16[1] = vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U));
  50. return tmp.u32;
  51. #endif
  52. }
  53. // fp8x4 -> half2x2
  54. template <>
  55. __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
  56. {
  57. union {
  58. uint2 u32x2;
  59. uint32_t u32[2];
  60. } tmp;
  61. tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
  62. tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
  63. return tmp.u32x2;
  64. }
  65. // fp8x8 -> half2x4
  66. template <>
  67. __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
  68. {
  69. union {
  70. uint4 u64x2;
  71. uint2 u64[2];
  72. } tmp;
  73. tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
  74. tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
  75. return tmp.u64x2;
  76. }
  77. using __nv_bfloat16 = __hip_bfloat16;
  78. // fp8 -> __nv_bfloat16
  79. template <>
  80. __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
  81. {
  82. hip_fp8 f8{a, hip_fp8::from_bits()};
  83. float f{f8};
  84. return __float2bfloat16(f);
  85. }
  86. using __nv_bfloat162 = __hip_bfloat162;
  87. // fp8x2 -> __nv_bfloat162
  88. template <>
  89. __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
  90. {
  91. __nv_bfloat162 res;
  92. res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
  93. res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
  94. return res;
  95. }
  96. // fp8x4 -> bf16_4_t
  97. template <>
  98. __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
  99. {
  100. bf16_4_t res;
  101. res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
  102. res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
  103. return res;
  104. }
  105. // fp8x8 -> bf16_8_t
  106. template <>
  107. __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
  108. {
  109. bf16_4_t tmp1, tmp2;
  110. tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
  111. tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
  112. bf16_8_t res;
  113. res.x = tmp1.x;
  114. res.y = tmp1.y;
  115. res.z = tmp2.x;
  116. res.w = tmp2.y;
  117. return res;
  118. }
  119. // fp8 -> float
  120. template <>
  121. __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
  122. {
  123. hip_fp8 fp8{a, hip_fp8::from_bits()};
  124. return static_cast<float>(fp8);
  125. }
  126. // fp8x2 -> float2
  127. template <>
  128. __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
  129. {
  130. #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  131. float2 res;
  132. const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  133. res.x = f2[0];
  134. res.y = f2[1];
  135. return res;
  136. #else
  137. float2 res;
  138. res.x = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a));
  139. res.y = vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U));
  140. return res;
  141. #endif
  142. }
  143. // fp8x4 -> float4
  144. template <>
  145. __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
  146. {
  147. Float4_ res;
  148. res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
  149. res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
  150. return res;
  151. }
  152. // fp8x8 -> float8
  153. template <>
  154. __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
  155. {
  156. Float4_ tmp1, tmp2;
  157. tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
  158. tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
  159. Float8_ res;
  160. res.x = tmp1.x;
  161. res.y = tmp1.y;
  162. res.z = tmp2.x;
  163. res.w = tmp2.y;
  164. return res;
  165. }
  166. // half -> fp8
  167. template <>
  168. __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
  169. {
  170. __half_raw tmp;
  171. tmp.x = a;
  172. hip_fp8 f8{static_cast<float>(tmp.data)};
  173. return f8.data;
  174. }
  175. // bf16 -> fp8
  176. template <>
  177. __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
  178. {
  179. hip_fp8 res{__bfloat162float(a)};
  180. return res.data;
  181. }
  182. // float -> fp8
  183. template <>
  184. __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
  185. {
  186. hip_fp8 f8(a);
  187. return f8.data;
  188. }
  189. // fp8x4 -> float4
  190. template <>
  191. __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
  192. {
  193. Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
  194. float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  195. return res;
  196. }
  197. // float2 -> half2
  198. template <>
  199. __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
  200. {
  201. union {
  202. half2 float16;
  203. uint32_t uint32;
  204. };
  205. float16 = __float22half2_rn(a);
  206. return uint32;
  207. }
  208. // Float4 -> half2x2
  209. template <>
  210. __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
  211. {
  212. uint2 b;
  213. float2 val;
  214. val.x = a.x.x;
  215. val.y = a.x.y;
  216. b.x = vec_conversion<uint32_t, float2>(val);
  217. val.x = a.y.x;
  218. val.y = a.y.y;
  219. b.y = vec_conversion<uint32_t, float2>(val);
  220. return b;
  221. }
  222. // Float4 -> float4
  223. template <>
  224. __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
  225. {
  226. float4 b;
  227. b.x = a.x.x;
  228. b.y = a.x.y;
  229. b.z = a.y.x;
  230. b.w = a.y.y;
  231. return b;
  232. }
  233. // Float8 -> half2x4
  234. template <>
  235. __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
  236. {
  237. uint4 b;
  238. b.x = vec_conversion<uint32_t, float2>(a.x);
  239. b.y = vec_conversion<uint32_t, float2>(a.y);
  240. b.z = vec_conversion<uint32_t, float2>(a.z);
  241. b.w = vec_conversion<uint32_t, float2>(a.w);
  242. return b;
  243. }
  244. // float2 -> bfloat162
  245. template <>
  246. __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
  247. {
  248. __nv_bfloat162 b = __float22bfloat162_rn(a);
  249. return b;
  250. }
  251. // Float4 -> bfloat162x2
  252. template <>
  253. __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
  254. {
  255. bf16_4_t b;
  256. b.x = __float22bfloat162_rn(a.x);
  257. b.y = __float22bfloat162_rn(a.y);
  258. return b;
  259. }
  260. // Float8 -> bfloat162x4
  261. template <>
  262. __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
  263. {
  264. bf16_8_t b;
  265. b.x = __float22bfloat162_rn(a.x);
  266. b.y = __float22bfloat162_rn(a.y);
  267. b.z = __float22bfloat162_rn(a.z);
  268. b.w = __float22bfloat162_rn(a.w);
  269. return b;
  270. }
  /* Scaled and vectorized conversions for exchanging data between the
     high-precision (HP: float / half / bfloat16) and low-precision (fp8) domains.
     Scale convention used throughout this API:
       Quantize:   FP8 = Quantize(HP / scale)
       Dequantize: HP  = Dequant(FP8) * scale
  */
  277. // fp8 -> half
  278. template <>
  279. __inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale)
  280. {
  281. hip_fp8 f8{a, hip_fp8::from_bits()};
  282. __half_raw res;
  283. res.data = static_cast<float>(f8) * scale;
  284. return res.x;
  285. }
  286. // fp8x2 -> half2
  287. template <>
  288. __inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, const float scale)
  289. {
  290. #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  291. const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  292. union {
  293. __half2_raw h2r;
  294. uint32_t ui32;
  295. } tmp;
  296. tmp.h2r.x.data = f2[0] * scale;
  297. tmp.h2r.y.data = f2[1] * scale;
  298. return tmp.ui32;
  299. #else
  300. union {
  301. uint16_t u16[2];
  302. uint32_t u32;
  303. } tmp;
  304. tmp.u16[0] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
  305. tmp.u16[1] = scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
  306. return tmp.u32;
  307. #endif
  308. }
  309. // fp8x4 -> half2x2
  310. template <>
  311. __inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale)
  312. {
  313. union {
  314. uint2 u32x2;
  315. uint32_t u32[2];
  316. } tmp;
  317. tmp.u32[0] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale);
  318. tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U), scale);
  319. return tmp.u32x2;
  320. }
  321. // fp8x8 -> half2x4
  322. template <>
  323. __inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale)
  324. {
  325. union {
  326. uint4 u64x2;
  327. uint2 u64[2];
  328. } tmp;
  329. tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  330. tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  331. return tmp.u64x2;
  332. }
  333. using __nv_bfloat16 = __hip_bfloat16;
  334. // fp8 -> __nv_bfloat16
  335. template <>
  336. __inline__ __device__ __nv_bfloat16 scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, const float scale)
  337. {
  338. hip_fp8 f8{a, hip_fp8::from_bits()};
  339. float f{f8};
  340. return __float2bfloat16(f * scale);
  341. }
  342. using __nv_bfloat162 = __hip_bfloat162;
  343. // fp8x2 -> __nv_bfloat162
  344. template <>
  345. __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, const float scale)
  346. {
  347. __nv_bfloat162 res;
  348. res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
  349. res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
  350. return res;
  351. }
  352. // fp8x4 -> bf16_4_t
  353. template <>
  354. __inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, const float scale)
  355. {
  356. bf16_4_t res;
  357. res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
  358. res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), scale);
  359. return res;
  360. }
  361. // fp8x8 -> bf16_8_t
  362. template <>
  363. __inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale)
  364. {
  365. bf16_4_t tmp1, tmp2;
  366. tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  367. tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  368. bf16_8_t res;
  369. res.x = tmp1.x;
  370. res.y = tmp1.y;
  371. res.z = tmp2.x;
  372. res.w = tmp2.y;
  373. return res;
  374. }
  375. // fp8 -> float
  376. template <>
  377. __inline__ __device__ float scaled_vec_conversion<float, uint8_t>(const uint8_t& a, const float scale)
  378. {
  379. hip_fp8 fp8{a, hip_fp8::from_bits()};
  380. return static_cast<float>(fp8) * scale;
  381. }
  382. // fp8x2 -> float2
  383. template <>
  384. __inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale)
  385. {
  386. #if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  387. float2 res;
  388. const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  389. res.x = f2[0] * scale;
  390. res.y = f2[1] * scale;
  391. return res;
  392. #else
  393. float2 res;
  394. res.x = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a), scale);
  395. res.y = scaled_vec_conversion<float, uint8_t>(static_cast<uint8_t>(a >> 8U), scale);
  396. return res;
  397. #endif
  398. }
  399. // fp8x4 -> float4
  400. template <>
  401. __inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale)
  402. {
  403. Float4_ res;
  404. res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
  405. res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
  406. return res;
  407. }
  408. // fp8x8 -> float8
  409. template <>
  410. __inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale)
  411. {
  412. Float4_ tmp1, tmp2;
  413. tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  414. tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
  415. Float8_ res;
  416. res.x = tmp1.x;
  417. res.y = tmp1.y;
  418. res.z = tmp2.x;
  419. res.w = tmp2.y;
  420. return res;
  421. }
  422. /* Quantize(HP / scale) => FP8 */
  423. // TODO: vectorized to add
  424. // half -> fp8
  425. template <>
  426. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale)
  427. {
  428. __half_raw tmp;
  429. tmp.x = a;
  430. hip_fp8 f8{static_cast<float>(tmp.data)/scale};
  431. return f8.data;
  432. }
  433. // bf16 -> fp8
  434. template <>
  435. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a, const float scale)
  436. {
  437. hip_fp8 res{__bfloat162float(a)/scale};
  438. return res.data;
  439. }
  440. // float -> fp8
  441. template <>
  442. __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(const float& a, const float scale)
  443. {
  444. hip_fp8 f8(a/scale);
  445. return f8.data;
  446. }
  447. // fp8x4 -> float4
  448. template <>
  449. __inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale)
  450. {
  451. Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  452. float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  453. return res;
  454. }
  455. }
  456. } // namespace aphrodite