common.h
/*
 * Adapted from https://github.com/InternLM/lmdeploy
 * Copyright (c) OpenMMLab. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cassert>
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <type_traits>

namespace aphrodite {
namespace autoquant {

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750))
#define APHRODITE_ARCH_SM75 1
#else
#define APHRODITE_ARCH_SM75 0
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define APHRODITE_ARCH_SM80 1
#else
#define APHRODITE_ARCH_SM80 0
#endif

constexpr int WARP_SIZE = 32;

#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__)
#if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__))
#define PRAGMA_UNROLL    _Pragma("unroll")
#define PRAGMA_NO_UNROLL _Pragma("unroll 1")
#else
#define PRAGMA_UNROLL    #pragma unroll
#define PRAGMA_NO_UNROLL #pragma unroll 1
#endif
#else
#define PRAGMA_UNROLL
#define PRAGMA_NO_UNROLL
#endif

static const float HALF_FLT_MAX = 65504.F;
// Modified from NVIDIA FasterTransformer:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
// Modified from llm-awq:
// https://github.com/mit-han-lab/llm-awq/blob/main/awq/kernels/csrc/quantization/dequantize.cuh
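// Dequantizes eight packed 4-bit integers (one uint32_t) into eight fp16 values,
// returned as a uint4 holding four half2 pairs. The input is assumed to use the
// interleaved nibble packing of the llm-awq / FasterTransformer kernels linked
// above, where output pairs elt_01 and elt_23 come from the low byte halves and
// elt_45 / elt_67 from the high byte halves of the source word.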
__inline__ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
    uint4     result;
    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut                = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK           = 0x000f000f;
    static constexpr uint32_t TOP_MASK              = 0x00f000f0;
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Note that the entire sequence only requires 1 shift instruction. This is
    // thanks to the register packing format and the fact that we force our
    // integers to be unsigned, and account for this in the fp16 subtractions. In
    // addition, I exploit the fact that sub and fma have the same throughput in
    // order to convert elt_23 and elt_67 to fp16 without having to shift them to
    // the bottom bits beforehand.
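    // Worked example of the magic-number trick: the fp16 bit pattern 0x6400 is
    // 1024.0 with an all-zero 10-bit mantissa, so OR-ing a bottom nibble v (0..15)
    // into the low mantissa bits yields the fp16 value 1024 + v, and subtracting
    // 1024 recovers v. A top nibble contributes v << 4, i.e. 1024 + 16*v, so an
    // fma by 1/16 (0x2c00) with addend -64 (0xd400) recovers v exactly:
    // (1024 + 16*v) * (1/16) - 64 == v.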
    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
    // dependency if we issue immediately before required.
    const uint32_t top_i4s = i4s >> 8;
    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
    asm("lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(h[0])
        : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
    asm("lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(h[1])
        : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
    asm("lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(h[2])
        : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
    asm("lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(h[3])
        : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

    // I use inline PTX below because I am not sure if the compiler will emit
    // float2half instructions if I use the half2 ctor. In this case, I chose
    // performance reliability over code readability.

    // This is the half2 {1032, 1032} represented as an integer.
    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // This is the half2 {-72, -72} represented as an integer.
    // static constexpr uint32_t NEG_72 = 0xd480d480;
    // Haotian: Let's use {-64, -64}.
    static constexpr uint32_t NEG_64 = 0xd400d400;

    // Finally, we construct the output numbers.
    // Convert elt_01
    asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_23
    asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // Convert elt_45
    asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_67
    asm("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

    return result;
}
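// Variant of dequantize_s4_to_fp16x2 used when the zero-points are dequantized
// through the same routine: in the active branch the remaining bias is left in
// place instead of being subtracted here, since it cancels once the dequantized
// zero-point is subtracted later. The disabled `if (0)` branch keeps the fully
// debiased version (subtracting 1024 / 64) for reference.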
__inline__ __device__ uint4 dequantize_s4_to_fp16x2_v2(uint32_t const& source)
{
    uint4     result;
    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
    uint32_t const& i4s = reinterpret_cast<uint32_t const&>(source);

    // First, we extract the i4s and construct an intermediate fp16 number.
    static constexpr uint32_t immLut      = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOT_MASK    = 0x000f000f;
    static constexpr uint32_t TOP_MASK    = 0x00f000f0;
    static constexpr uint32_t MAGIC_NUM_0 = 0x64006400;        // `1024`
    static constexpr uint32_t MAGIC_NUM_1 = 0x54005400;        // `64`
    static constexpr uint32_t MAGIC_NUM_2 = MAGIC_NUM_1 >> 4;  // `64` >> 4

    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
    // dependency if we issue immediately before required.
    const uint32_t top_i4s = i4s >> 8;

    if (0) {  // 1024 & 64
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_0), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(MAGIC_NUM_0));
        asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(MAGIC_NUM_1));
        asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(MAGIC_NUM_0));
        asm("sub.f16x2 %0, %1, %2;\n" : "=r"(h[3]) : "r"(h[3]), "r"(MAGIC_NUM_1));
    }
    else {  // 64 only; trade 4 hfma2 for 2 shifts
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[0]) : "r"(i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[1]) : "r"(i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[2]) : "r"(top_i4s), "n"(BOT_MASK), "n"(MAGIC_NUM_2), "n"(immLut));
        asm("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(h[3]) : "r"(top_i4s), "n"(TOP_MASK), "n"(MAGIC_NUM_1), "n"(immLut));
        h[0] <<= 4;
        h[2] <<= 4;
        // We don't need to subtract the magic numbers here: the zero-points go
        // through the same dequant function and carry the same magic constant, so
        // the bias cancels out when the zero-points are subtracted.
    }
    return result;
}
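// bf16 counterpart of dequantize_s4_to_fp16x2_v2. The magic number 0x4300 is the
// bf16 encoding of 128.0, so OR-ing a nibble v into the low mantissa bits yields
// 128 + v; with only 7 mantissa bits there is no room for the fp16-style
// top-nibble trick, so each nibble is shifted down and converted individually.
// As in the fp16 _v2 path, the bias is not subtracted here and is expected to
// cancel once the zero-point, dequantized the same way, is subtracted.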
__inline__ __device__ uint4 dequantize_s4_to_bf16x2_v2(uint32_t const& source)
{
    uint4     result;
    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
    uint32_t const& source_i4s = reinterpret_cast<uint32_t const&>(source);

    // First, we extract the i4s and construct an intermediate bf16 number.
    static constexpr uint32_t immLut                 = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t MASK                   = 0x000f000f;
    static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300;

    // We don't have enough mantissa to remove as much shift overhead as FP16, so we
    // must loop. No shift is needed for the first item.
    uint32_t i4s = source_i4s;
    asm("lop3.b32 %0, %1, %2, %3, %4;\n"
        : "=r"(h[0])
        : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
    PRAGMA_UNROLL
    for (int ii = 1; ii < 4; ++ii) {
        i4s >>= 4;
        // (i4s & 0x000f000f) | 0x43004300
        asm("lop3.b32 %0, %1, %2, %3, %4;\n"
            : "=r"(h[ii])
            : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut));
    }
    return result;
}
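// Converts a generic pointer that refers to shared memory into the 32-bit
// shared-state-space address expected by PTX instructions such as ldmatrix.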
__inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
    uint32_t smem_int_ptr;
    asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n"
        : "=r"(smem_int_ptr)
        : "l"(ptr));
    return smem_int_ptr;
}
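// Wrappers around the sm_75+ ldmatrix instruction, which loads 8x8 tiles of 16-bit
// elements from shared memory into the register layout used by tensor-core mma
// instructions; the x4 and x2 forms load four and two tiles respectively. On older
// architectures they simply assert.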
__inline__ __device__ void ldmatrix_m8n8_x4_b16(uint& d0, uint& d1, uint& d2, uint& d3, uint32_t smem_int_ptr)
{
#if APHRODITE_ARCH_SM75
    asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
        : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3)
        : "r"(smem_int_ptr));
#else
    assert(APHRODITE_ARCH_SM75);
#endif
}
__inline__ __device__ void ldmatrix_m8n8_x2_b16(uint& d0, uint& d1, uint32_t smem_int_ptr)
{
#if APHRODITE_ARCH_SM75
    asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(d0), "=r"(d1) : "r"(smem_int_ptr));
#else
    assert(APHRODITE_ARCH_SM75);
#endif
}
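// Block-level spin wait / release on a global flag. In wait_flag, thread 0 polls
// the flag (with an acquire load on sm_70+, a .cg load otherwise), and
// __syncthreads_and lets the whole block leave the loop together once the target
// status is observed; release_flag publishes a new status with a release store.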
__inline__ __device__ void wait_flag(int* lock, int status, int thread_id)
{
    int state = 0;
    while (__syncthreads_and(state != status)) {
        if (thread_id == 0) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
            asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
            asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
        }
    }
    __syncthreads();  // memory fence
}
__inline__ __device__ void release_flag(int* lock, int status, int thread_id)
{
    __syncthreads();  // memory fence
    if (thread_id == 0) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
        asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
        asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
    }
}
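// Applies a quantization parameter pair q to a packed value x, computing
// (x - zero) * scale, where q.x holds the zero-point and q.y the scale
// (each broadcast to both lanes).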
template <typename T>
__inline__ __device__ T apply_Q(const T& x, const T& q);

template <>
__inline__ __device__ half2 apply_Q(const half2& x, const half2& q)
{
    uint s, z;
    (half2&)z = __halves2half2(q.x, q.x);
    (half2&)s = __halves2half2(q.y, q.y);

    auto& t = (const uint&)x;
    uint  u, v;
    asm("sub.ftz.f16x2 %0, %1, %2;\n" : "=r"(u) : "r"(t), "r"(z));
    asm("mul.ftz.f16x2 %0, %1, %2;\n" : "=r"(v) : "r"(u), "r"(s));
    return (half2&)v;
}
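// Thin wrappers around the CUDA bf16 intrinsics. bf16 arithmetic is only supported
// on sm_80 and newer, so on older architectures these wrappers simply assert.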
inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    assert(false);
#else
    return __hsub2(a, b);
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 a, const __nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    assert(false);
#else
    return __hmul2(a, b);
#endif
}

inline __device__ __nv_bfloat162 halves2bfloat162(const __nv_bfloat16 a, const __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    assert(false);
#else
    return __halves2bfloat162(a, b);
#endif
}

inline __device__ float2 bfloat1622float2(const __nv_bfloat162 a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    assert(false);
#else
    return __bfloat1622float2(a);
#endif
}

inline __device__ __nv_bfloat162 float22bfloat162_rn(const float2 a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    assert(false);
#else
    return __float22bfloat162_rn(a);
#endif
}
template <>
__inline__ __device__ __nv_bfloat162 apply_Q(const __nv_bfloat162& x, const __nv_bfloat162& q)
{
    __nv_bfloat162 s, z;
    (__nv_bfloat162&)z = halves2bfloat162(q.x, q.x);
    (__nv_bfloat162&)s = halves2bfloat162(q.y, q.y);

    __nv_bfloat162 u, v;
    u = bf16hsub2(x, z);
    v = bf16hmul2(u, s);
    return v;
}

__device__ __forceinline__ float clamp_inf_for_half(const float input)
{
    // clamp inf values to enable fp16 training
    return input > 0.0f ? min(input, (HALF_FLT_MAX - 1000) / 2.0) : max(input, (-HALF_FLT_MAX + 1000) / 2.0);
}
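// Minimal fixed-size array with constexpr element access on both host and device;
// a stripped-down stand-in for std::array that is safe to use in device code.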
template<typename T, int N>
struct Array {
    T a[N];

    __device__ __host__ constexpr T& operator[](int i) noexcept
    {
        return a[i];
    }
    __device__ __host__ constexpr const T& operator[](int i) const noexcept
    {
        return a[i];
    }
};
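// Compile-time shape whose extents are carried as template parameters. get<I>()
// returns a std::integral_constant, so the extents stay usable as compile-time
// constants; m()/n()/k() and c()/s() are GEMM-style and convolution-style accessors
// for the leading extents, and count() is the product of all extents.
// A minimal usage sketch (the extents below are only illustrative):
//   constexpr Shape<128, 128, 32> tile{};
//   static_assert(tile.m() == 128 && tile.count() == 128 * 128 * 32);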
template<int... Ns>
struct Shape {
    static constexpr Array<int, sizeof...(Ns)> data_{Ns...};

    constexpr Shape() = default;

    Shape(std::integral_constant<int, Ns>...) {}

    template<int index>
    constexpr auto get() const noexcept
    {
        return std::integral_constant<int, data_[index]>{};
    }
    constexpr auto m() const noexcept
    {
        return get<0>();
    }
    constexpr auto n() const noexcept
    {
        return get<1>();
    }
    constexpr auto k() const noexcept
    {
        return get<2>();
    }
    constexpr int c() const noexcept
    {
        return get<0>();
    }
    constexpr int s() const noexcept
    {
        return get<1>();
    }
    constexpr int count() const noexcept
    {
        return (Ns * ...);
    }
};

}  // namespace autoquant
}  // namespace aphrodite