quant_utils.cuh

// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
#pragma once

#include <assert.h>
#include <stdint.h>
#include <float.h>

#include <type_traits>

#include "../../attention/attention_dtypes.h"

namespace aphrodite {
namespace int8 {
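
// Affine int8 helpers: quant() packs fp32/fp16/bf16 scalars and vectors into int8
// lanes stored in 16/32/64-bit integers, dequant() expands packed int8 back to fp32
// vectors, and vec_conversion() casts those fp32 results to packed fp16/bf16 types.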

// float32 to int8
inline __device__ int8_t quant(float a, const float scale, const float zp)
{
  int8_t int8;
  int8 = round(max(-128.f, min(127.f, (a - zp) / scale)));
  return int8;
}
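
// The vector overloads below write each int8 lane through an anonymous union so the
// packed result can be returned as a single 16-, 32-, or 64-bit integer.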

// float32x2 to int8x2
inline __device__ short quant(float2 a, const float scale, const float zp)
{
  union {
    int8_t int8[2];
    short int16;
  };
  int8[0] = quant(a.x, scale, zp);
  int8[1] = quant(a.y, scale, zp);
  return int16;
}

// float32x4 to int8x4
inline __device__ int32_t quant(float4 a, const float scale, const float zp)
{
  union {
    int8_t int8[4];
    int32_t int32;
  };
  int8[0] = quant(a.x, scale, zp);
  int8[1] = quant(a.y, scale, zp);
  int8[2] = quant(a.z, scale, zp);
  int8[3] = quant(a.w, scale, zp);
  return int32;
}
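
// The fp16 overloads below take half values in the packed integer forms used by the
// attention dtypes: uint16_t = one half, uint32_t = half2, uint2 = four halves,
// uint4 = eight halves.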

// float16 to int8
inline __device__ int8_t quant(uint16_t a, const float scale, const float zp)
{
  int8_t int8;
  float b = half_to_float(a);
  int8 = quant(b, scale, zp);
  return int8;
}

// float16x2 to int8x2
inline __device__ int16_t quant(uint32_t a, const float scale, const float zp)
{
  union {
    int8_t int8[2];
    short int16;
  };
  float2 b = half2_to_float2(a);
  int8[0] = quant(b.x, scale, zp);
  int8[1] = quant(b.y, scale, zp);
  return int16;
}

// float16x4 to int8x4
inline __device__ int32_t quant(uint2 a, const float scale, const float zp)
{
  union {
    int16_t int16[2];
    int32_t int32;
  };
  int16[0] = quant(a.x, scale, zp);
  int16[1] = quant(a.y, scale, zp);
  return int32;
}

// float16x8 to int8x8
inline __device__ int64_t quant(uint4 a, const float scale, const float zp)
{
  union {
    int16_t int16[4];
    int64_t int64;
  };
  int16[0] = quant(a.x, scale, zp);
  int16[1] = quant(a.y, scale, zp);
  int16[2] = quant(a.z, scale, zp);
  int16[3] = quant(a.w, scale, zp);
  return int64;
}

// bf16 to int8
inline __device__ int8_t quant(__nv_bfloat16 a, const float scale, const float zp)
{
  int8_t int8;
  float b = to_float(a);
  int8 = quant(b, scale, zp);
  return int8;
}

// bf16x2 to int8x2
inline __device__ int16_t quant(__nv_bfloat162 a, const float scale, const float zp)
{
  union {
    int8_t int8[2];
    short int16;
  };
  float2 b = bf1622float2(a);
  int8[0] = quant(b.x, scale, zp);
  int8[1] = quant(b.y, scale, zp);
  return int16;
}

// bf16x4 to int8x4
inline __device__ int32_t quant(bf16_4_t a, const float scale, const float zp)
{
  union {
    int16_t int16[2];
    int32_t int32;
  };
  int16[0] = quant(a.x, scale, zp);
  int16[1] = quant(a.y, scale, zp);
  return int32;
}

// bf16x8 to int8x8
inline __device__ int64_t quant(bf16_8_t a, const float scale, const float zp)
{
  union {
    int16_t int16[4];
    int64_t int64;
  };
  int16[0] = quant(a.x, scale, zp);
  int16[1] = quant(a.y, scale, zp);
  int16[2] = quant(a.z, scale, zp);
  int16[3] = quant(a.w, scale, zp);
  return int64;
}
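
// Dequantization applies the inverse affine map (a * scale + zp) and returns fp32
// vector types (float2, Float4_, Float8_); vec_conversion() below converts those to
// the desired output format.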

// int8 to float32, then `vec_conversion` to target format
inline __device__ float dequant(int8_t a, const float scale, const float zp)
{
  float b = a * scale + zp;
  return b;
}

// int8x2 to float32x2
inline __device__ float2 dequant(int16_t a, const float scale, const float zp)
{
  union {
    int8_t int8[2];
    int16_t int16;
  };
  int16 = a;
  float2 b;
  b.x = int8[0] * scale + zp;
  b.y = int8[1] * scale + zp;
  return b;
}

// int8x4 to float32x4
inline __device__ Float4_ dequant(int32_t a, const float scale, const float zp)
{
  union {
    int8_t int8[4];
    int32_t int32;
  };
  int32 = a;
  Float4_ b;
  b.x.x = (int8[0] * scale) + zp;
  b.x.y = (int8[1] * scale) + zp;
  b.y.x = (int8[2] * scale) + zp;
  b.y.y = (int8[3] * scale) + zp;
  return b;
}

// int8x8 to float32x8
inline __device__ Float8_ dequant(int64_t a, const float scale, const float zp)
{
  union {
    int16_t int16[4];
    int64_t int64;
  };
  int64 = a;
  Float8_ b;
  b.x = dequant(int16[0], scale, zp);
  b.y = dequant(int16[1], scale, zp);
  b.z = dequant(int16[2], scale, zp);
  b.w = dequant(int16[3], scale, zp);
  return b;
}
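
// vec_conversion(): identity by default; the specializations below convert the fp32
// vectors produced by dequant() into packed fp16 (uint32_t/uint2/uint4), float4, and
// bf16 (__nv_bfloat162/bf16_4_t/bf16_8_t) forms.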

template<typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x)
{
  return x;
}

template<>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
{
  union {
    half2 float16;
    uint32_t uint32;
  };
  float16 = __float22half2_rn(a);
  return uint32;
}

template<>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
{
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val);
  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val);
  return b;
}

template<>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
{
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
}

template<>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
{
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x);
  b.y = vec_conversion<uint32_t, float2>(a.y);
  b.z = vec_conversion<uint32_t, float2>(a.z);
  b.w = vec_conversion<uint32_t, float2>(a.w);
  return b;
}

template<>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2& a) {
  __nv_bfloat162 b;
  from_float(b, a);
  return b;
}

template<>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
  bf16_4_t b;
  from_float(b, a);
  return b;
}

template<>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
  bf16_8_t b;
  from_float(b, a);
  return b;
}
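
// Illustrative sketch (added for exposition, not part of the original FasterTransformer
// code): shows how the helpers above compose. Eight packed fp16 values (uint4) are
// quantized to eight int8 lanes packed into an int64_t, dequantized back to fp32
// (Float8_), and converted to packed fp16 again. The function name is hypothetical.
inline __device__ uint4 quant_dequant_roundtrip_fp16x8(uint4 v, const float scale, const float zp)
{
  const int64_t q = quant(v, scale, zp);     // 8 x fp16 -> 8 x int8 (packed in int64_t)
  const Float8_ f = dequant(q, scale, zp);   // 8 x int8 -> 8 x fp32 (Float8_)
  return vec_conversion<uint4, Float8_>(f);  // 8 x fp32 -> 8 x fp16 (packed in uint4)
}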

}  // namespace int8
}  // namespace aphrodite