1
0

quant_utils.cuh 6.6 KB


  1. #pragma once
  2. #include <assert.h>
  3. #include <stdint.h>
  4. #include <float.h>
  5. #include <type_traits>
  6. #include "../../attention/attention_dtypes.h"
  7. #include "../../attention/dtype_float32.cuh"
  8. #include "../../attention/dtype_float16.cuh"
  9. #include "../../attention/dtype_bfloat16.cuh"
  10. namespace aphrodite {
  11. #ifdef ENABLE_FP8_E5M2
  12. namespace fp8_e5m2_unscaled {
  // Generic fallback for vec_conversion: produces a Tout copy-initialized from
  // `x` (a plain copy when Tout == Tin). The explicit specializations in this
  // namespace override it for the packed fp8 <-> half / bf16 / float cases.
  template <typename Tout, typename Tin>
  __inline__ __device__ Tout vec_conversion(const Tin& x) {
    Tout out = x;
    return out;
  }
  // fp8 (e5m2, one byte) -> fp16. Returns the raw 16-bit half bit pattern.
  template <>
  __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
    const __half_raw as_half = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
    return as_half.x;
  }
  // fp8x2 -> half2: converts two packed e5m2 bytes and packs the two resulting
  // half bit patterns into one 32-bit word, .x in the low 16 bits and .y in the
  // high 16 bits (the same layout a u16[2]/u32 union gives on the
  // little-endian GPU).
  template <>
  __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
    const __half2_raw pair = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
    const uint32_t low_bits = static_cast<uint32_t>(pair.x);
    const uint32_t high_bits = static_cast<uint32_t>(pair.y);
    return low_bits | (high_bits << 16U);
  }
  // fp8x4 -> half2x2: the low fp8 pair of `a` becomes out.x and the high fp8
  // pair becomes out.y, each as a packed half2 bit pattern.
  template <>
  __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
    const uint16_t low_pair = static_cast<uint16_t>(a);
    const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);
    uint2 out;
    out.x = vec_conversion<uint32_t, uint16_t>(low_pair);
    out.y = vec_conversion<uint32_t, uint16_t>(high_pair);
    return out;
  }
  // fp8x8 -> half2x4: the four fp8 values in a.x fill out.x/out.y and the four
  // in a.y fill out.z/out.w.
  template <>
  __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
    const uint2 first = vec_conversion<uint2, uint32_t>(a.x);
    const uint2 second = vec_conversion<uint2, uint32_t>(a.y);
    uint4 out;
    out.x = first.x;
    out.y = first.y;
    out.z = second.x;
    out.w = second.y;
    return out;
  }
  // fp8 -> __nv_bfloat16. There is no direct fp8 -> bf16 conversion intrinsic,
  // so the value goes fp8 -> half -> float -> bf16.
  template <>
  __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
    const __half_raw widened = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
    return __float2bfloat16(half_to_float(widened.x));
  }
  // fp8x2 -> __nv_bfloat162: the low byte of `a` becomes .x, the high byte .y.
  template <>
  __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
    const uint8_t lo_byte = static_cast<uint8_t>(a);
    const uint8_t hi_byte = static_cast<uint8_t>(a >> 8U);
    __nv_bfloat162 out;
    out.x = vec_conversion<__nv_bfloat16, uint8_t>(lo_byte);
    out.y = vec_conversion<__nv_bfloat16, uint8_t>(hi_byte);
    return out;
  }
  // fp8x4 -> bf16_4_t: each 16-bit half of `a` holds a packed fp8 pair; the low
  // half fills .x and the high half fills .y.
  template <>
  __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
    bf16_4_t out;
    out.x = vec_conversion<__nv_bfloat162, uint16_t>(static_cast<uint16_t>(a & 0xFFFFU));
    out.y = vec_conversion<__nv_bfloat162, uint16_t>(static_cast<uint16_t>(a >> 16U));
    return out;
  }
  // fp8x8 -> bf16_8_t: the four fp8 values in a.x fill .x/.y and the four in
  // a.y fill .z/.w.
  template <>
  __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
    const bf16_4_t front = vec_conversion<bf16_4_t, uint32_t>(a.x);
    const bf16_4_t back = vec_conversion<bf16_4_t, uint32_t>(a.y);
    bf16_8_t out;
    out.x = front.x;
    out.y = front.y;
    out.z = back.x;
    out.w = back.y;
    return out;
  }
  // fp8 -> float: widen to a raw half via the fp8 -> half specialization, then
  // convert that half bit pattern to float.
  template <>
  __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
    return half_to_float(vec_conversion<uint16_t, uint8_t>(a));
  }
  // fp8x2 -> float2: widen the pair to a packed half2 word, then unpack it to
  // two floats.
  template <>
  __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a) {
    return half2_to_float2(vec_conversion<uint32_t, uint16_t>(a));
  }
  // fp8x4 -> Float4_ (the project's pair-of-float2 type, not CUDA's float4):
  // the low fp8 pair of `a` becomes .x and the high pair becomes .y.
  template <>
  __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
    const uint16_t low_pair = static_cast<uint16_t>(a);
    const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);
    Float4_ out;
    out.x = vec_conversion<float2, uint16_t>(low_pair);
    out.y = vec_conversion<float2, uint16_t>(high_pair);
    return out;
  }
  // fp8x8 -> Float8_: the four fp8 values in a.x fill .x/.y and the four in
  // a.y fill .z/.w.
  template <>
  __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
    const Float4_ front = vec_conversion<Float4_, uint32_t>(a.x);
    const Float4_ back = vec_conversion<Float4_, uint32_t>(a.y);
    Float8_ out;
    out.x = front.x;
    out.y = front.y;
    out.z = back.x;
    out.w = back.y;
    return out;
  }
  // fp16 (raw 16-bit pattern) -> fp8 e5m2, using __NV_SATFINITE so that
  // out-of-range finite inputs saturate instead of overflowing.
  template <>
  __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
    __half_raw src;
    src.x = a;
    return static_cast<uint8_t>(__nv_cvt_halfraw_to_fp8(src, __NV_SATFINITE, __NV_E5M2));
  }
  // bf16 -> fp8 e5m2, using __NV_SATFINITE so out-of-range finite inputs
  // saturate instead of overflowing.
  // Only implemented for SM80+. When compiled for an older architecture the
  // device-side assert fires if assertions are enabled.
  template <>
  __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
  {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    assert(false);
    // Without an explicit return, this path falls off the end of a non-void
    // function whenever assert() is compiled out (NDEBUG). That is undefined
    // behavior and triggers a missing-return warning. Return a defined value
    // instead: 0 is the e5m2 encoding of +0.0.
    return 0;
  #else
    __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
    return (uint8_t)res;
  #endif
  }
  // fp32 -> fp8 e5m2, using __NV_SATFINITE so out-of-range finite inputs
  // saturate instead of overflowing.
  template <>
  __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
    return static_cast<uint8_t>(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2));
  }
  // fp8x4 -> CUDA float4: convert to Float4_ first, then flatten
  // (.x.x, .x.y, .y.x, .y.y) into (x, y, z, w).
  template <>
  __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a) {
    const Float4_ pairs = vec_conversion<Float4_, uint32_t>(a);
    float4 out;
    out.x = pairs.x.x;
    out.y = pairs.x.y;
    out.z = pairs.y.x;
    out.w = pairs.y.y;
    return out;
  }
  // float2 -> half2 (round-to-nearest via __float22half2_rn), returned as the
  // packed 32-bit bit pattern of the half2.
  template <>
  __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a) {
    union {
      half2 h2;
      uint32_t bits;
    } pun;
    pun.h2 = __float22half2_rn(a);
    return pun.bits;
  }
  // Float4_ -> uint2: each float2 half of `a` is rounded to half2 and stored as
  // a packed 32-bit word (a.x -> out.x, a.y -> out.y).
  template <>
  __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
    float2 first;
    first.x = a.x.x;
    first.y = a.x.y;
    float2 second;
    second.x = a.y.x;
    second.y = a.y.y;
    uint2 out;
    out.x = vec_conversion<uint32_t, float2>(first);
    out.y = vec_conversion<uint32_t, float2>(second);
    return out;
  }
  // Float4_ -> CUDA float4: a pure reshuffle of the four floats, with no
  // precision change.
  template <>
  __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
    return make_float4(a.x.x, a.x.y, a.y.x, a.y.y);
  }
  // Float8_ -> uint4: each float2 member is rounded to half2 and becomes one
  // packed 32-bit word, in member order (x, y, z, w).
  template <>
  __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
    return make_uint4(vec_conversion<uint32_t, float2>(a.x),
                      vec_conversion<uint32_t, float2>(a.y),
                      vec_conversion<uint32_t, float2>(a.z),
                      vec_conversion<uint32_t, float2>(a.w));
  }
  // float2 -> __nv_bfloat162, delegating the rounding to from_float().
  template <>
  __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a)
  {
    __nv_bfloat162 converted;
    from_float(converted, a);
    return converted;
  }
  // Float4_ -> bf16_4_t, delegating the rounding to from_float().
  template <>
  __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a)
  {
    bf16_4_t converted;
    from_float(converted, a);
    return converted;
  }
  // Float8_ -> bf16_8_t, delegating the rounding to from_float().
  template <>
  __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a)
  {
    bf16_8_t converted;
    from_float(converted, a);
    return converted;
  }
  242. } // namespace fp8_e5m2_unscaled
  243. #endif // ENABLE_FP8_E5M2
  244. } // namespace aphrodite