1
0

dequantize.cuh 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. /*
  2. Adapted from https://github.com/mit-han-lab/llm-awq
  3. Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
  4. @article{lin2023awq,
  5. title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  6. author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  7. journal={arXiv},
  8. year={2023}
  9. }
  10. */
  11. #pragma once
  namespace aphrodite {
  namespace awq {

  /*
   * Unpacks the eight unsigned 4-bit integers held in `source` into eight fp16
   * values, returned as four packed f16x2 words inside a uint4.
   *
   * Nibble k of `source` occupies bits [4k, 4k + 3]. Following the masks and the
   * single 8-bit shift used below, the 32-bit words of the result hold
   * (low 16 bits, high 16 bits):
   *   word 0: nibble 0, nibble 4
   *   word 1: nibble 1, nibble 5
   *   word 2: nibble 2, nibble 6
   *   word 3: nibble 3, nibble 7
   * Every nibble is converted as an unsigned value in [0, 15]; no zero-point
   * offset is applied by this function.
   *
   * When compiled for __CUDA_ARCH__ < 750 the conversion is not provided. The
   * function asserts, and if asserts are compiled out (NDEBUG) it returns an
   * all-zero uint4 so that control never falls off the end of a non-void
   * function.
   */
  __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
  {
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
    assert(false);
    // Explicit return: with NDEBUG the assert above expands to nothing, and
    // flowing off the end of a function returning uint4 is undefined behaviour.
    return uint4{0u, 0u, 0u, 0u};
  #else
    uint4 result;

    // View the result as four 32-bit words, each carrying one f16x2 pair.
    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
    uint32_t const i4s = source;

    // First, we extract the i4s and construct an intermediate fp16 number.
    // immLut evaluates to 0xEA, which makes lop3 compute (a & b) | c.
    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
    static constexpr uint32_t TOP_MASK = 0x00f000f0;
    // half2 {1024, 1024}: OR-ing a nibble into the low mantissa bits yields
    // 1024 + v (BOTTOM_MASK) or 1024 + 16 * v (TOP_MASK).
    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

    // Note that the entire sequence only requires 1 shift instruction. This is
    // thanks to the register packing format and the fact that we force our
    // integers to be unsigned, and account for this in the fp16 subtractions.
    // In addition, sub and fma have the same throughput, which is exploited to
    // convert elt_23 and elt_67 to fp16 without having to shift them to the
    // bottom bits beforehand.

    // Shift right by 8 to now consider elt_45 and elt_67. Issued first to hide
    // the RAW dependency that would appear if issued immediately before use.
    const uint32_t top_i4s = i4s >> 8;

    // Extract elt_01: (i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[0])
                 : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_23: (i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[1])
                 : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_45: (top_i4s & 0x000f000f) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[2])
                 : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
    // Extract elt_67: (top_i4s & 0x00f000f0) | 0x64006400
    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
                 : "=r"(h[3])
                 : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

    // Inline PTX is used below because it is not certain that the compiler
    // would avoid float2half instructions if the half2 ctor were used.
    // Performance reliability was chosen over code readability.

    // Upstream used the half2 {1032, 1032} (0x64086408) to map to [-8, 7].
    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7].
    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
    // Upstream used the half2 {-72, -72} (0xd480d480).
    // Haotian: Let's use {-64, -64}.
    static constexpr uint32_t NEG_64 = 0xd400d400;

    // Finally, we construct the output numbers.
    // Convert elt_01: (1024 + v) - 1024 = v
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_23: (1024 + 16 * v) * (1 / 16) - 64 = v
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
    // Convert elt_45: (1024 + v) - 1024 = v
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
    // Convert elt_67: (1024 + 16 * v) * (1 / 16) - 64 = v
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

    return result;
  #endif
  }

  } // namespace awq
  } // namespace aphrodite